"""
This module contains a base algorithm for block-level optimization.
It should include storage tools, inference tools, and block detection heuristic
"""

import logging
import os
import re
import shutil
import tempfile
from abc import ABC, abstractmethod
from copy import deepcopy
from enum import Enum
from typing import Callable, Dict, Optional, Tuple

import psutil
import tensorflow as tf
from tqdm import tqdm

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model.hailo_model.model_flow import ModelFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasValueError
from hailo_model_optimization.algorithms.algorithm_base import AlgoResults
from hailo_model_optimization.algorithms.block_by_block.cache_utils import clean_cache, get_max_cache_size
from hailo_model_optimization.algorithms.dali_utils.data_feeder_tfrecord import DataFeederTFRecord

# from hailo_model_optimization.algorithms.dali_utils.dataset_util import tf_unpad_input
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class SetLimvalsMode(Enum):
    """
    Operation selector for ``BlockByBlock._set_limvals``.
    """

    # Back up each layer's activation clip range and replace it with the layer's output limvals.
    set = "set"
    # Restore the clip range that was backed up by an earlier ``set`` pass.
    reset = "reset"


class BlockByBlock(OptimizationAlgorithm, ABC):
    """
    Base class for block separation of the model.

    A full-precision reference model (``fp_model``) and the model being optimized are run
    block after block. The activations passed between blocks are written to, and read back
    from, cache directories created inside a temporary working directory.

    Subclasses supply the block partition and the dataset/batch/compression settings.
    ``data_feeder`` is initialised to ``None`` here and has to be assigned before any of the
    cache helpers is used.
    """

    def __init__(
        self,
        model: HailoModel,
        fp_model: HailoModel,
        model_config: ModelOptimizationConfig,
        name: str,
        dataset: tf.data.Dataset,
        *args,
        work_dir: Optional[str] = None,
        logger_level=logging.DEBUG,
        logger=None,
        **kwargs,
    ):
        """
        Store the models and dataset, and create the temporary working directory.

        Args:
            model: the model being optimized (forwarded to the base class).
            fp_model: the reference model used for the native run of every block.
            model_config: optimization configuration (forwarded to the base class).
            name: algorithm name (forwarded to the base class).
            dataset: the original dataset; its elements are consumed as two-item tuples
                by ``prep_label_less``.
            work_dir: parent directory of the temporary cache directory. Defaults to the
                current working directory; created if it does not exist.
            logger_level: forwarded to the base class.
            logger: forwarded to the base class.

        """
        super().__init__(model, model_config, name, logger_level, logger, *args, **kwargs)
        self._dataset = dataset
        self._fp_model = fp_model

        work_dir = work_dir if work_dir is not None else os.getcwd()
        os.makedirs(work_dir, exist_ok=True)
        # Every instance gets its own uniquely named directory, removed again by clean_cache_dir().
        self._work_dir = tempfile.mkdtemp(dir=work_dir, prefix=snake_case(f"{self._name} cache_"))
        self._base_cache_dir = None

        # Backup of the clip ranges overridden by _set_limvals, keyed by the layer's full name.
        self._clip_range_bck = {}
        self.data_feeder: DataFeederTFRecord = None

    @abstractmethod
    def get_blocks(self) -> Dict[str, ModelFlow]:
        """
        Separate the model into list of blocks

        NOTE(review): the return annotation says ``Dict`` but ``_comparative_run`` consumes
        the result with ``len()`` and ``pop(0)``, i.e. as a list of flows - confirm and align
        the annotation.
        """

    @abstractmethod
    def get_dataset_size(self) -> int:
        """
        Get the dataset size used by the algorithm (for block iteration)
        """

    @abstractmethod
    def get_batch_size(self) -> int:
        """
        Get the batch size used by the algorithm (for block iteration)
        """

    @abstractmethod
    def get_compression_type(self) -> str:
        """
        Compression type of block activations when cached to disk.
        """

    def get_internal_encoding(self) -> bool:
        """
        Enable or disable internal encodings between ops and layers.

        This base implementation has no return statement and therefore yields ``None``,
        which ``_comparative_run`` treats as "disabled". Subclasses are presumably expected
        to override it.
        """

    def get_eager_mode(self) -> bool:
        """
        Enable or disable eager mode.

        This base implementation has no return statement and therefore yields ``None``.
        """

    def create_cache_dir(self, prefix):
        """
        Create a new, uniquely named cache directory inside the working directory.

        The created path is also remembered in ``self._base_cache_dir``.

        Args:
            prefix: name prefix of the new directory.

        Returns:
            Path of the created directory.

        """
        self._base_cache_dir = tempfile.mkdtemp(dir=self._work_dir, prefix=prefix)
        return self._base_cache_dir

    def clean_cache_dir(self):
        """
        Remove the whole working directory (and all caches in it), ignoring errors.
        """
        shutil.rmtree(self._work_dir, ignore_errors=True)

    def _prepare_original_dataset_as_cache(self, dataset_size, batch_size) -> tf.data.Dataset:
        """
        Build the input pipeline over the original dataset that is later written to a cache.

        Args:
            dataset_size: number of samples to take from the original dataset.
            batch_size: batch size of the returned dataset.

        Returns:
            The label-less, dict-packed, truncated, cached and batched dataset, mapped through
            the model's ``preproc_cb`` when one is set.

        """

        def _aux_fun(dataset: tf.data.Dataset) -> tf.data.Dataset:
            """
            An auxiliary function for preparing a dataset. This function:
                1. Removes the label from the dataset.
                2. Packs the dataset(s) into a dictionary.
                3. Takes the first dataset_size samples from the dataset.

            Args:
                dataset (tf.data.Dataset): the input dataset.

            Returns:
                dataset (tf.data.Dataset): the processed dataset.

            """
            autotune = tf.data.AUTOTUNE
            # Operate on the dataset handed in by ``apply`` rather than on ``self._dataset``.
            dataset = dataset.map(prep_label_less, num_parallel_calls=autotune)
            dataset = dataset.map(self._model.inputs_as_dict, num_parallel_calls=autotune)
            return dataset.take(dataset_size)

        dataset = self._dataset.apply(_aux_fun)
        # The cache is placed before batching, so individual samples are what gets cached.
        dataset = dataset.cache().prefetch(tf.data.AUTOTUNE)
        dataset = dataset.batch(batch_size)

        if self._model.preproc_cb is not None:
            dataset = dataset.map(self._model.preproc_cb, num_parallel_calls=tf.data.AUTOTUNE)
        return dataset

    @staticmethod
    def get_build_inputs(dataset):
        """
        Get the inputs for the model build function.

        Args:
            dataset: either a single sample given as a dict, or an iterable whose first
                element is used as the sample.

        Returns:
            For a dict sample, a dict with the same keys whose values are ``[1, *shape[1:]]``;
            otherwise the tuple ``(1, *shape[1:])``. The leading dimension of the sample is
            always replaced by 1.

        """
        dataset_sample = dataset if isinstance(dataset, dict) else next(iter(dataset))
        if isinstance(dataset_sample, dict):
            # If the dataset is a dict, we need to extract the shape of each input
            build_inputs = {k: [1, *v.shape[1:]] for k, v in dataset_sample.items()}
        else:
            build_inputs = tuple([1, *dataset_sample.shape[1:]])

        return build_inputs

    def infer_block_with_cache(
        self,
        block_model: HailoModel,
        all_blocks: Dict[str, ModelFlow],
        interlayer_results: Dict[str, str],
        cache_dir: str,
        count: int,
    ):
        """
        Infer a block. Reads the input data from a cache and writes the results to the cache

        Args:
            block_model: sub-model of the block to run; it is built first if needed.
            all_blocks: blocks handed to ``clean_cache`` (``_comparative_run`` passes the
                blocks that have not been processed yet).
            interlayer_results: cache bookkeeping dict; updated in place with the entries
                of this block's outputs.
            cache_dir: directory the block outputs are written to.
            count: currently unused by this method.

        """
        # Read data from cache dir to tensorflow dataset
        dataset = self.data_feeder.cache_to_dataset(
            block_model.flow.input_nodes,
            interlayer_results,
            batch_size=self.get_batch_size(),
        )
        dataset = dataset.cache()
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        # Infer for each data sample in dataset.
        if not block_model.built:
            block_model.build(self.get_build_inputs(dataset))
        dataset_results = dataset.map(block_model, num_parallel_calls=1, deterministic=False)

        # Assign layer keys (element_spec) for the dataset
        output_layers = block_model.flow.output_nodes

        def _map_output_layer_name(*args):
            # The i-th positional output is stored under the name of the i-th output node.
            return {lname: args[i] for i, lname in enumerate(output_layers)}

        dataset_results = dataset_results.map(
            _map_output_layer_name,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=True,
        )
        # Save dataset to cache dir
        new_cache_dir_per_layer = self.data_feeder.dataset_to_cache(dataset_results, cache_dir)
        interlayer_results.update(new_cache_dir_per_layer)
        # Presumably drops cached results that no remaining block needs - see cache_utils.
        clean_cache(all_blocks, interlayer_results, output_layers)

    def _setup_cache(self, cache_name: str) -> Tuple[str, Dict[str, str]]:
        """
        Create cache dir and copy the input dataset to the cache
        Returns:
            Cache dir, and dict for interlayer results
        """
        cache_dir = self.create_cache_dir(cache_name)
        dataset_size = self.get_dataset_size()
        batch_size = self.get_batch_size()

        dataset = self._prepare_original_dataset_as_cache(dataset_size, batch_size)
        interlayer_results = self.data_feeder.dataset_to_cache(dataset, cache_dir)
        return cache_dir, interlayer_results

    def _comparative_run(
        self,
        pre_quant_cb: Callable = None,
        preproc_cb: Callable = None,
        postproc_cb: Callable = None,
    ):
        """
        Run every block natively and lossy, one block at a time, with optional hooks.

        For each block the reference sub-model is inferred first (native cache), then the
        sub-model of the optimized model is inferred in lossy mode (quant cache), so that the
        following blocks get their inputs from the respective cache.

        Args:
            pre_quant_cb: called before the lossy inference of a block with
                ``(block_model, interlayer_results_native, interlayer_results_quant)``.
            preproc_cb: called before the native inference of a block with
                ``(fp_block_model, interlayer_results_native, interlayer_results_quant)``.
            postproc_cb: called after the lossy inference of a block with
                ``(block_model, interlayer_results_native, interlayer_results_quant)``.

        """
        blocks = self.get_blocks()
        native_cache_dir, interlayer_results_native = self._setup_cache("native_cache")
        quant_cache_dir, interlayer_results_quant = self._setup_cache("quant_cache")
        # The context manager closes the progress bar even when a block raises.
        with tqdm(total=len(blocks), unit="blocks", desc=self._name) as pbar:
            while blocks:
                block_flow = blocks.pop(0)
                # Inner nodes of the block that do not exist in the reference model.
                missing_nodes = (
                    block_flow.nodes
                    - self._fp_model.layers
                    - set(block_flow.output_nodes)
                    - set(block_flow.input_nodes)
                )
                if missing_nodes:
                    # Work on a copy so the flow used for the optimized model stays intact.
                    fp_block_flow = deepcopy(block_flow)
                    for node in missing_nodes:
                        fp_block_flow.remove_layer(node)
                else:
                    fp_block_flow = block_flow
                fp_block_model = self._fp_model.get_sub_model(fp_block_flow)
                fp_block_model.set_native()

                self._set_limvals(fp_block_model, mode=SetLimvalsMode.set)
                pbar.set_postfix({"Layers": block_flow.output_nodes})  # Is there better identifier?
                dataset_size = self.get_dataset_size()

                # Pre-processing native model
                if preproc_cb is not None:
                    preproc_cb(fp_block_model, interlayer_results_native, interlayer_results_quant)
                # Infer native for reference
                self.infer_block_with_cache(
                    fp_block_model,
                    blocks,
                    interlayer_results_native,
                    native_cache_dir,
                    dataset_size,
                )

                # Starting the quantized section for the block
                block_model = self._model.get_sub_model(block_flow)
                block_model.set_lossy(native_act=True)
                if not self.get_internal_encoding():
                    block_model.disable_internal_encoding()

                # Optimize the block
                if pre_quant_cb is not None:
                    pre_quant_cb(block_model, interlayer_results_native, interlayer_results_quant)

                # Prepare for inference. Without this, inference for mobilenets might drop by about
                # 10 percents.
                if not self.get_internal_encoding():
                    block_model.enable_internal_encoding()
                # Infer Lossy for future inputs
                self.infer_block_with_cache(
                    block_model,
                    blocks,
                    interlayer_results_quant,
                    quant_cache_dir,
                    dataset_size,
                )

                if postproc_cb is not None:
                    postproc_cb(block_model, interlayer_results_native, interlayer_results_quant)

                self._set_limvals(fp_block_model, mode=SetLimvalsMode.reset)
                pbar.update(1)

    def _set_limvals(self, fp_block_model, mode=SetLimvalsMode.set):
        """
        Override, or restore, the activation clip range of the reference block's layers.

        Layers are skipped when they are missing from ``self._model``, are
        ``BaseHailoNonNNCoreLayer`` instances, or have no activation op.

        Args:
            fp_block_model: reference sub-model whose layers are modified.
            mode: ``SetLimvalsMode.set`` backs up the current clip range and replaces it with
                the first output limvals of the matching layer in ``self._model``;
                ``SetLimvalsMode.reset`` writes the backed-up value back.

        Raises:
            AccelerasValueError: if ``mode`` is neither ``set`` nor ``reset`` (raised when the
                first eligible layer is reached).

        """
        for lname in fp_block_model.layers:
            if lname not in self._model.layers:
                continue
            layer = self._model.layers[lname]
            if isinstance(layer, BaseHailoNonNNCoreLayer):
                continue
            if layer.activation_atomic_op is None:
                continue
            if mode == SetLimvalsMode.set:
                limvals = layer.get_output_limvals()[0]
                self._clip_range_bck[layer.full_name] = deepcopy(
                    fp_block_model.layers[lname].activation_atomic_op._clip_range,
                )
                fp_block_model.layers[lname].activation_atomic_op._clip_range = limvals
            elif mode == SetLimvalsMode.reset:
                fp_block_model.layers[lname].activation_atomic_op._clip_range = self._clip_range_bck[layer.full_name]
            else:
                raise AccelerasValueError("Illegal mode value to _set_limvals")

    def check_storage_usage(self, dali_cache=False):
        """
        Check if the current work directory has sufficient storage for running the current algorithm.
        # factor 3 if dali, else 2 (technically)

        The estimate is ``max_cache_size * 4 * dataset_size`` bytes (the 4 is presumably the
        size of a float32 element - TODO confirm). When it exceeds the free space of the work
        directory, a warning is logged if a compression type is configured, otherwise an error
        is logged and raised.

        Args:
            dali_cache: forwarded to ``get_max_cache_size``.

        Raises:
            OSError: if the estimate exceeds the free disk space and no compression type is set.

        """
        bytes_per_gb = 10**9  # one unit for both figures so the logged numbers are comparable
        compression_type = self.get_compression_type()
        max_cache_size = get_max_cache_size(self._model, self.get_blocks(), dali_cache)
        max_disk_usage = max_cache_size * 4 * self.get_dataset_size()
        free_disk = psutil.disk_usage(self._work_dir).free
        be_verb = "might" if compression_type else "will"
        self._logger.info(
            f"The algorithm {self._name} {be_verb} use up to {max_disk_usage / bytes_per_gb:.02f} GB of storage space",
        )
        if max_disk_usage > free_disk:
            free_disk_gb = free_disk / bytes_per_gb
            be_verb = "might be" if compression_type else "is"
            msg = (
                f"Current free disk space {free_disk_gb:0.02f} GB {be_verb} insufficient. "
                f"Consider reducing the dataset size if you encounter space issues."
            )
            # NOTE(review): the compression hint is only added when a compression type is
            # already set - confirm this is the intended branch.
            msg = (f"{msg} Additionally, you can enable cache compression.") if compression_type else msg
            if compression_type:
                self._logger.warning(msg)
            else:
                self._logger.error(msg)
                raise OSError(msg)

    def run(self) -> Tuple[HailoModel, AlgoResults]:
        """
        Run the algorithm through the base class and always remove the working directory.

        Returns:
            Whatever the base class ``run`` returns.

        """
        try:
            retval = super().run()
        finally:
            # Clean up the on-disk caches on success and on failure alike.
            self.clean_cache_dir()
        return retval


def prep_label_less(images, image_info):
    """
    Dataset map function that drops the second element of a sample.

    Args:
        images: first element of the dataset sample; returned unchanged.
        image_info: second element of the dataset sample; ignored.

    Returns:
        ``images``, exactly as received.

    """
    # TODO: The train label is off by one, it needs to be fixed
    return images


def snake_case(s):
    """
    Convert a free-form name into lower-case words joined by underscores.

    Hyphens and whitespace separate words, and a word boundary is also inserted in front of
    every run of capitals and every capitalised word. Underscores already present in the
    input are kept as part of their word, e.g. ``"AdaRound cache_"`` gives
    ``"ada_round_cache_"``.

    Args:
        s: the string to convert.

    Returns:
        The converted string; empty when ``s`` holds no words.

    """
    spaced = s.replace("-", " ")
    # Put a space in front of each run of capitals ("HTTPServer" -> " HTTPServer") ...
    spaced = re.sub(r"([A-Z]+)", r" \1", spaced)
    # ... and in front of each capitalised word (" HTTPServer" -> " HTTP Server").
    spaced = re.sub(r"([A-Z][a-z]+)", r" \1", spaced)
    words = spaced.split()
    return "_".join(words).lower()
