"""
This module contains a base algorithm for block-level optimization.
It should include storage tools, inference tools, and a block detection heuristic.
"""

import logging
import os
import re
import shutil
import tempfile
from abc import ABC, abstractmethod
from copy import deepcopy
from enum import Enum
from functools import partial
from typing import Callable, Dict, Optional, Tuple

import psutil
import tensorflow as tf
from tqdm import tqdm

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model.hailo_model.model_flow import ModelFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasValueError
from hailo_model_optimization.algorithms.algorithm_base import AlgoResults
from hailo_model_optimization.algorithms.block_by_block.cache_utils import (
    clean_cache,
    dataset_from_cache,
    dataset_to_cache,
    get_layer_cache_dir,
    get_max_cache_size,
    save_batch,
)
from hailo_model_optimization.algorithms.dali_utils.dataset_util import tf_unpad_input
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class SetLimvalsMode(Enum):
    """Operation selector for BlockByBlock._set_limvals."""

    # Back up the block layer's activation clip range and replace it with the layer's output limvals
    set = "set"
    # Put back the clip range that was saved by an earlier "set"
    reset = "reset"


class BlockByBlock(OptimizationAlgorithm, ABC):
    """
    Base class for block separation of the model.
    """

    def __init__(
        self,
        model: HailoModel,
        fp_model: HailoModel,
        model_config: ModelOptimizationConfig,
        name: str,
        dataset: tf.data.Dataset,
        *args,
        is_disable_internal_encoding: bool = False,
        work_dir: Optional[str] = None,
        logger_level=logging.DEBUG,
        logger=None,
        **kwargs,
    ):
        """
        Keep the models and dataset, and create a private working directory for the caches.

        Args:
            model: model forwarded to the parent algorithm
            fp_model: model whose sub-models are inferred natively as a reference
            model_config: configuration forwarded to the parent algorithm
            name: algorithm name, forwarded to the parent algorithm
            dataset: dataset that is copied into the cache before block iteration
            is_disable_internal_encoding: when True, internal encoding of each block is disabled
                before the pre-quantization callback and enabled again before block inference
            work_dir: parent folder for the temporary working directory; the current working
                directory is used when None
            logger_level: forwarded to the parent algorithm
            logger: forwarded to the parent algorithm
        """
        super().__init__(model, model_config, name, logger_level, logger, *args, **kwargs)
        self._fp_model = fp_model
        self._dataset = dataset
        self._is_disable_internal_encoding = is_disable_internal_encoding
        # Clip ranges saved by _set_limvals so that they can be restored later
        self._clip_range_bck = {}
        # Most recent directory handed out by create_cache_dir
        self._base_cache_dir = None

        parent_dir = os.getcwd() if work_dir is None else work_dir
        os.makedirs(parent_dir, exist_ok=True)
        dir_prefix = snake_case(f"{self._name} cache_")
        # Every instance gets its own uniquely named folder below parent_dir
        self._work_dir = tempfile.mkdtemp(dir=parent_dir, prefix=dir_prefix)

    def create_cache_dir(self, prefix):
        """
        Create new cache directory in the base_cache_dir folder
        """
        self._base_cache_dir = tempfile.mkdtemp(dir=self._work_dir, prefix=prefix)
        return self._base_cache_dir

    def clean_cache_dir(self):
        shutil.rmtree(self._work_dir, ignore_errors=True)

    def _prepare_original_dataset_as_cache(self, dataset_size, cache_dir, batch_size, compress):
        """
        Write the original dataset into the given cache directory.

        The image info is stripped, the inputs are converted with the model's inputs_as_dict,
        and only the first dataset_size samples are kept. When the model defines a preproc_cb,
        it is applied on batches of batch_size before the data is cached.

        Returns:
            Whatever dataset_to_cache returns (used by callers as the interlayer results dict)
        """
        model = self._model
        dataset = self._dataset.map(prep_label_less).map(model.inputs_as_dict).take(dataset_size)
        preproc_cb = model.preproc_cb
        if preproc_cb is not None:
            # The callback runs on batches; unbatch afterwards so the cache receives single samples
            dataset = dataset.batch(batch_size).map(preproc_cb).unbatch()
        return dataset_to_cache(dataset, model.layers, cache_dir, batch_size, compress)

    @abstractmethod
    def get_blocks(self) -> Dict[str, ModelFlow]:
        """
        Separate the model into blocks.

        Returns:
            The blocks to iterate over, one ModelFlow per block.
            NOTE(review): the annotation says Dict[str, ModelFlow], but _comperative_run consumes
            the result with ``pop(0)`` like a list - confirm the intended container type.
        """

    @abstractmethod
    def get_dataset_size(self) -> int:
        """
        Get the dataset size used by the algorithm (for block iteration).

        The value limits how many samples are copied into the cache and is passed as the
        sample count to every block inference.
        """

    @abstractmethod
    def get_batch_size(self) -> int:
        """
        Get the batch size used by the algorithm (for block iteration).

        The value is used both when caching the original dataset and when inferring each block.
        """

    @abstractmethod
    def is_compressed_cache(self) -> bool:
        """
        Indicate whether the interlayer results cache should be compressed or not.

        The flag is forwarded to the cache writers and also softens the disk-space check
        (a warning instead of an error when compression is enabled).
        """

    def infer_block_with_cache(
        self,
        current_block: HailoModel,
        all_blocks: Dict[str, ModelFlow],
        interlayer_results: Dict[str, str],
        cache_dir: str,
        count: int,
    ):
        """
        Infer one block, reading its inputs from the cache and writing its outputs back to it.

        Args:
            current_block: sub-model of the block to infer
            all_blocks: blocks that are still pending; forwarded to clean_cache
            interlayer_results: cache locations of the available layer results. Updated in place
                with the cache dirs of this block's output layers
            cache_dir: base directory under which the output layers' cache dirs are resolved
            count: sample count forwarded to the block inference
        """
        output_layers = current_block.flow.output_nodes
        # One cache directory per output layer of the block
        cache_dir_per_layer = {
            layer_name: get_layer_cache_dir(cache_dir, layer_name) for layer_name in output_layers
        }

        compress = self.is_compressed_cache()
        # Only the batch result and its index are left for _infer_block to supply
        batch_handler = partial(self._save_batch, current_block, cache_dir_per_layer, compress)
        batch_size = self.get_batch_size()
        self._infer_block(current_block, interlayer_results, count, batch_size, batch_handler)

        interlayer_results.update(cache_dir_per_layer)
        clean_cache(all_blocks, interlayer_results, output_layers)

    @staticmethod
    def _save_batch(
        current_block: HailoModel,
        cache_dir_per_layer: Dict[str, str],
        compress: bool,
        result: tf.Tensor,
        result_index: int,
    ):
        """
        Save the result of a single batch to the cache of every output layer of the block.

        Args:
            current_block: block model that produced the result
            cache_dir_per_layer: cache directory of every output layer
            compress: forwarded to save_batch
            result: output of the block for one batch
            result_index: index forwarded to save_batch (the caller passes the index of the
                first sample of the batch)
        """
        output_layers = current_block.flow.output_nodes
        # A block with a single output layer that has a single output yields a bare result;
        # wrap it so it can be sliced like the multi-output case
        if len(output_layers) == 1 and current_block.layers[output_layers[0]].num_outputs == 1:
            result = [result]

        start = 0
        for layer_name in output_layers:
            # TODO: what is the behavior when output_nodes > 1 & a single layer has multiple outputs
            layer = current_block.layers[layer_name]
            stop = start + layer.num_outputs
            save_batch(layer, result[start:stop], cache_dir_per_layer[layer_name], result_index, compress)
            start = stop

    def _setup_cache(self, cache_name: str) -> Tuple[str, Dict[str, str]]:
        """
        Create cache dir and copy the input dataset to the cache
        Returns:
            Cache dir, and dict for interlayer results
        """
        cache_dir = self.create_cache_dir(cache_name)
        interlayer_results = self._prepare_original_dataset_as_cache(
            self.get_dataset_size(),
            cache_dir,
            self.get_batch_size(),
            self.is_compressed_cache(),
        )
        return cache_dir, interlayer_results

    @staticmethod
    def _infer_block(
        block_model: HailoModel,
        interlayer_results: Dict[str, str],
        count: int,
        batch_size: int,
        batch_handler: Callable[[tf.Tensor, int], None],
    ):
        """
        Infer a block of the model on its cached inputs and hand every batch result to a callback.

        Args:
            block_model: block to infer; its flow.input_nodes select which cached results are read
            interlayer_results: cache locations forwarded to dataset_from_cache
            count: sample count forwarded to dataset_from_cache
            batch_size: number of samples per inference batch
            batch_handler: called as ``batch_handler(result, first_sample_index)`` for every batch,
                where first_sample_index is ``batch_index * batch_size``

        Returns:
            The value returned by the last batch_handler call, or None when the dataset
            yields no batches

        """
        data, unpadded_shape = dataset_from_cache(block_model.flow.input_nodes, interlayer_results, count)

        @tf.function
        def call_block(inputs: tf.Tensor):
            inputs = tf_unpad_input(inputs, unpadded_shape)
            result = block_model(inputs)
            return result

        # Defined up front so that an empty dataset returns None instead of raising UnboundLocalError
        retval = None
        for batch_index, sample in enumerate(data.batch(batch_size)):
            # TODO: add data slicing for varying shape
            result = call_block(sample)
            retval = batch_handler(result, batch_index * batch_size)
        return retval

    def _comperative_run(
        self,
        pre_quant_cb: Optional[Callable] = None,
        preproc_cb: Optional[Callable] = None,
        postproc_cb: Optional[Callable] = None,
    ):
        """
        Iterate over the blocks, inferring each one natively (reference) and quantized.

        Two caches are kept side by side: a native cache fed by the full-precision model and a
        quant cache fed by the optimized model. For every block the native sub-model is inferred
        first, then the quantized sub-model is prepared, optionally optimized by ``pre_quant_cb``,
        and inferred so that its outputs serve as inputs of the following blocks.

        Args:
            pre_quant_cb: optional callback, called with
                (quantized block model, native interlayer results, quant interlayer results)
                after the block was set to lossy mode and before it is inferred
            preproc_cb: optional callback, called with
                (native block model, native interlayer results, quant interlayer results)
                before the native inference of the block
            postproc_cb: optional callback, called with
                (quantized block model, native interlayer results, quant interlayer results)
                after the quantized inference of the block

        NOTE(review): the blocks are consumed with ``pop(0)``, i.e. get_blocks() is expected to
        return a list-like object here despite its Dict annotation - confirm.
        """
        blocks = self.get_blocks()
        native_cache_dir, interlayer_results_native = self._setup_cache("native_cache")
        quant_cache_dir, interlayer_results_quant = self._setup_cache("quant_cache")

        # The context manager closes the progress bar even when a block raises
        with tqdm(total=len(blocks), unit="blocks", desc=self._name) as pbar:
            while blocks:
                current_block = blocks.pop(0)
                # Nodes of the block that do not exist in the full-precision model. The block's
                # input and output nodes are never treated as missing
                missing_nodes = (
                    current_block.nodes
                    - self._fp_model.layers
                    - set(current_block.output_nodes)
                    - set(current_block.input_nodes)
                )
                if missing_nodes:
                    # Work on a copy so the block used for the quantized model stays intact
                    fp_block = deepcopy(current_block)
                    for node in missing_nodes:
                        fp_block.remove_layer(node)
                else:
                    fp_block = current_block
                fp_block_model = self._fp_model.get_sub_model(fp_block)
                fp_block_model.set_native()

                self._set_limvals(fp_block_model, mode=SetLimvalsMode.set)
                pbar.set_postfix({"Layers": current_block.output_nodes})  # Is there better identifier?
                dataset_size = self.get_dataset_size()

                if preproc_cb is not None:
                    preproc_cb(fp_block_model, interlayer_results_native, interlayer_results_quant)
                # Infer native for reference
                self.infer_block_with_cache(
                    fp_block_model,
                    blocks,
                    interlayer_results_native,
                    native_cache_dir,
                    dataset_size,
                )

                # Starting the quantized section for the block
                block_model = self._model.get_sub_model(current_block)
                block_model.set_lossy(native_act=True)
                if self._is_disable_internal_encoding:
                    block_model.disable_internal_encoding()

                # Optimize the block
                if pre_quant_cb is not None:
                    pre_quant_cb(block_model, interlayer_results_native, interlayer_results_quant)

                # Prepare for inference. Without this, inference for mobilenets might drop by about
                # 10 percents.
                if self._is_disable_internal_encoding:
                    block_model.enable_internal_encoding()
                # Infer Lossy for future inputs
                self.infer_block_with_cache(
                    block_model,
                    blocks,
                    interlayer_results_quant,
                    quant_cache_dir,
                    dataset_size,
                )

                if postproc_cb is not None:
                    postproc_cb(block_model, interlayer_results_native, interlayer_results_quant)
                self._set_limvals(fp_block_model, mode=SetLimvalsMode.reset)
                pbar.update(1)

    def _set_limvals(self, fp_block_model, mode=SetLimvalsMode.set):
        """
        Override or restore the activation clip range of the layers of a native block model.

        Layers are skipped when they do not exist in the optimized model, when the optimized
        model's layer is a BaseHailoNonNNCoreLayer, or when it has no activation op.

        Args:
            fp_block_model: native block model whose activation clip ranges are modified
            mode: SetLimvalsMode.set backs up the current clip range and replaces it with the
                first output limvals of the matching layer in the optimized model;
                SetLimvalsMode.reset restores the backed-up clip range

        Raises:
            AccelerasValueError: if mode is neither set nor reset (raised on the first layer
                that is not skipped)
        """
        for lname in fp_block_model.layers:
            if lname not in self._model.layers:
                continue
            layer = self._model.layers[lname]
            if isinstance(layer, BaseHailoNonNNCoreLayer) or layer.activation_atomic_op is None:
                continue

            if mode == SetLimvalsMode.reset:
                saved_range = self._clip_range_bck[layer.full_name]
                fp_block_model.layers[lname].activation_atomic_op._clip_range = saved_range
            elif mode == SetLimvalsMode.set:
                limvals = layer.get_output_limvals()[0]
                fp_activation = fp_block_model.layers[lname].activation_atomic_op
                # The backup is keyed by the optimized model layer's full name
                self._clip_range_bck[layer.full_name] = deepcopy(fp_activation._clip_range)
                fp_block_model.layers[lname].activation_atomic_op._clip_range = limvals
            else:
                raise AccelerasValueError("Illegal mode value to _set_limvals")

    def check_storage_usage(self, dali_cache=False):
        """
        Check if the current work directory has sufficient storage for running the current algorithm.

        The estimated peak usage is logged. When it exceeds the free space of the working
        directory's disk, a warning is logged if the cache is compressed (the estimate is then
        only an upper bound); otherwise an error is logged and raised.

        Args:
            dali_cache: forwarded to get_max_cache_size

        Raises:
            OSError: when the estimated usage exceeds the free disk space and cache compression
                is disabled
        """
        # Both sizes are reported with the same unit so that the two messages are comparable
        bytes_per_gb = 10**9

        compression = self.is_compressed_cache()
        max_cache_size = get_max_cache_size(self._model, self.get_blocks(), dali_cache)
        # NOTE(review): the factor 4 is presumably the number of bytes per cached element - confirm
        max_disk_usage = max_cache_size * 4 * self.get_dataset_size()
        free_disk = psutil.disk_usage(self._work_dir).free
        be_verb = "might" if compression else "will"
        self._logger.info(
            f"The algorithm {self._name} {be_verb} use up to {max_disk_usage / bytes_per_gb:.02f} GB of storage space",
        )
        if max_disk_usage <= free_disk:
            return

        be_verb = "might be" if compression else "is"
        msg = (
            f"Current free disk space {free_disk / bytes_per_gb:0.02f} GB {be_verb} insufficient. "
            "Consider reducing the dataset size if you encounter space issues."
        )
        if compression:
            # With compression the real usage may be lower than the estimate, so only warn
            self._logger.warning(msg)
            return

        # Enabling compression is only a useful hint when it is not enabled already
        msg = f"{msg} Additionally, you can enable cache compression."
        self._logger.error(msg)
        raise OSError(msg)

    def run(self) -> Tuple[HailoModel, AlgoResults]:
        """
        Run the algorithm through the parent implementation and always remove the cache afterwards.

        Returns:
            The value returned by the parent's run()
        """
        try:
            return super().run()
        finally:
            # The working directory is deleted whether the run succeeded or raised
            self.clean_cache_dir()


def prep_label_less(images, image_info):
    """
    Dataset map function that keeps the images and discards the accompanying image info.

    Returns:
        The images argument, unchanged
    """
    # TODO: The train label is off by one, it needs to be fixed
    return images


def snake_case(s):
    """
    Convert a string to lower-case words joined by underscores.

    Hyphens and whitespace act as word separators, and a new word starts before every run of
    upper-case letters and before every capitalized word (e.g. "HTTPServer" -> "http_server").
    """
    spaced = s.replace("-", " ")
    # Put a space before every run of capitals ...
    spaced = re.sub("([A-Z]+)", r" \1", spaced)
    # ... and before every capitalized word, which splits an acronym from the word after it
    spaced = re.sub("([A-Z][a-z]+)", r" \1", spaced)
    words = spaced.split()
    return "_".join(words).lower()
