import re
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Dict, List, Set

import tensorflow as tf

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationLayerConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_base import CommandMeta
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerConfigBaseModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import FeatureConfigBaseModel
from hailo_model_optimization.acceleras.utils.acceleras_definitions import TrackerStage
from hailo_model_optimization.acceleras.utils.dataset_util import get_dataset_length
from hailo_model_optimization.algorithms.algorithm_base import AlgorithmBase


class OptimizationAlgorithm(AlgorithmBase, ABC):
    """
    Base class for all the algorithms in Model Optimization.

    This gives a common interface for all the MO blocks and a generic flow:
    ``run`` first finalizes (validates and unfolds) the algorithm's config and then
    delegates to ``AlgorithmBase.run``.

    Args:
        model: Mutable, the algorithm may change the model
        model_config: dict - Params needed for the block
        name: the name of the algorithm
        logger: the logger we use if needed
    Example:
        >>> model = HailoModel()
        >>> model_config = {}
        >>> mo_algo = SomeOptimizationAlgorithm(model, model_config)  # a concrete subclass
        >>> mo_algo.run()

    """

    @staticmethod
    def get_layer_name_in_config(acceleras_layer):
        """Return the name under which ``acceleras_layer`` is keyed in the MO config (its ``full_name``)."""
        # TODO: remove this function & usages.
        return acceleras_layer.full_name

    def _get_layer_cfg_from_layer_name_in_model(self, lname) -> ModelOptimizationLayerConfig:
        """Fetch the model-level layer config of the model layer named ``lname``."""
        acceleras_layer = self._model.layers[lname]
        lname_in_cfg = self.get_layer_name_in_config(acceleras_layer)
        return self._model_config.layers[lname_in_cfg]

    @property
    def optimization_target(self):
        """The optimization target of the wrapped model."""
        return self._model.optimization_target

    def run(self):
        """Finalize the algorithm's config, then run the algorithm via ``AlgorithmBase.run``."""
        self.finalize_config()
        return super().run()

    def finalize_config(self):
        """
        Finalize configuration for algo. (Should happen before should_skip_algo logic)

        The mo script parser doesn't apply any verification or "unfolding" of the config,
        therefore each algorithm has to verify the validity of the config, unfold
        the glob syntax and add scope to layer names.
        The basic flow goes as follows:
        1. finalize_layer_cfg - fix the individual layers config of the respective
                algorithm. (add scope, unfold glob, and resolve conflicts)
            a. _filter_empty_layer_cfg - filters out "empty" config. Required in some algorithms for simplification
            b. _get_valid_layer_cfg - filters out invalid config
        2. finalize_global_cfg - fill missing fields (if any), verify validity of the algorithm's config
        3. finalize_flat_layers_fields - unfold glob syntax in algorithms where the layers are a
                simple argument (such as qft freeze layers)
        """
        config = self.get_algo_config()
        if "layers" in config.nested_keys():
            config.layers = self.finalize_layer_cfg(config.layers)
        self.finalize_global_cfg(config)
        self.finalize_flat_layers_fields(config)

    @abstractmethod
    def finalize_global_cfg(self, algo_config):
        """
        Finalize the algorithm's config. (values that are not layer specific)
        Can include values verification, fetching data from the other algo's config, etc...
        """

    def finalize_layer_cfg(self, layers_cfg_dict: Dict[str, LayerConfigBaseModel]):
        """
        Finalize a layer config dict: expand glob syntax and add scope to implicit layer names.

        Keys of ``layers_cfg_dict`` are layer names as written in the script. ``{...}`` keys are globs;
        wildcard globs (containing ``*``) are silently narrowed to the valid config per layer, while
        explicit names are validated (see ``_validate_layer_config``). Overlapping entries are merged
        by ``_meta_aware_deep_update``.

        Returns:
            dict mapping scoped layer name to a new instance of the same config class as the input values.
        """
        if not layers_cfg_dict:
            return {}
        final_cfg = {}
        # TODO: glob is lower priority?
        for alls_lname, cfg in layers_cfg_dict.items():
            cfg = self._filter_empty_layer_cfg(cfg.dict())

            is_glob_lname = alls_lname.startswith("{") and alls_lname.endswith("}")
            if is_glob_lname:
                glob_pattern = alls_lname[1:-1]
                layers = self._model.get_layer_names_from_glob(glob_pattern)
                is_wildchar_glob = "*" in glob_pattern
            else:
                lname = self._model.get_layer_name_with_scope(alls_lname)
                self._validate_layer_config(lname, cfg)
                layers = [lname]
                is_wildchar_glob = False

            if is_wildchar_glob:
                cfg_per_layer = {lname: self._get_valid_layer_cfg(lname, deepcopy(cfg)) for lname in layers}
            else:
                # Every layer needs its own copy: _meta_aware_deep_update stores these dicts and later
                # mutates them (and their "meta" dict) in place, so sharing one dict across the layers of
                # a glob would leak an update of one layer into all the others.
                cfg_per_layer = {lname: deepcopy(cfg) for lname in layers}

            for layer, layer_cfg in cfg_per_layer.items():
                final_cfg = self._meta_aware_deep_update(final_cfg, layer, layer_cfg, is_wildchar_glob)

        layer_cfg_class = type(next(iter(layers_cfg_dict.values())))
        return {lname: layer_cfg_class(**cfg) for lname, cfg in final_cfg.items()}

    def _validate_layer_config(self, lname, cfg):
        """
        Check if given layer config is valid.

        If the invalid fields have explicit meta info (i.e. were set by a script command), raise
        ValueError quoting the command. If they come from "default" values, log to the debug logger.

        Not an optimal check and might be a bit buggy in some cases: only fields that
        ``_get_valid_layer_cfg`` removes are reported as invalid.

        Raises:
            ValueError: an explicitly configured field is not valid for this layer.
        """
        legal_cfg = self._get_valid_layer_cfg(lname, deepcopy(cfg))
        # This check isn't optimal, because it is a bit too harsh.
        if legal_cfg == cfg:
            return
        invalid_keys = sorted(cfg.keys() - legal_cfg.keys())
        invalid_values = [cfg[key] for key in invalid_keys]
        meta = cfg.get("meta")
        if meta is None:
            meta = {}
        # Only keys that actually carry meta info can be traced back to a script command.
        explicit_invalid_keys = [key for key in invalid_keys if key in meta]
        if not explicit_invalid_keys:
            if isinstance(self._model.layers[lname], BaseHailoLayer):
                self._logger.debug(
                    f"Invalid values were loaded as default {lname} ({invalid_keys} - {invalid_values})",
                )
            return
        raise ValueError(
            f"Unsupported value {invalid_values} for fields {invalid_keys} in layer {lname}.\n"
            f"Original command was {meta[explicit_invalid_keys[0]].command}",
        )

    @staticmethod
    def _meta_aware_deep_update(current_cfg: Dict[str, dict], layer, cfg: dict, is_wildchar_glob: bool):
        """
        Merge ``cfg`` into ``current_cfg[layer]`` with consideration to the meta info.

        - If ``layer`` has no config yet, ``cfg`` is stored as-is (not copied).
        - Otherwise a field is overwritten when the existing field has no meta info, or when the new
          field's command is on a later line than the existing one. When ``cfg`` carries meta info,
          only fields that have meta info are candidates for update.
        - Every overwritten field with meta info gets a fresh ``CommandMeta`` tagged with ``is_wildchar_glob``.

        ``current_cfg`` is mutated in place and also returned.

        This code might have some bugs in edge cases. It hasn't been checked thoroughly for all the algorithms.
        """
        if layer not in current_cfg:
            current_cfg[layer] = cfg
            return current_cfg

        new_meta_info: Dict[str, CommandMeta] = cfg.get("meta")
        if new_meta_info is None:
            new_meta_info = dict()

        existing_meta_info: Dict[str, CommandMeta] = current_cfg[layer].get("meta")
        if existing_meta_info is None:
            existing_meta_info = dict()

        keys_to_update: Set[str] = cfg.keys() - {"meta"}
        if new_meta_info:
            keys_to_update &= new_meta_info.keys()

        for key in keys_to_update:
            new_field_meta_info = new_meta_info.get(key)
            existing_field_meta_info = existing_meta_info.get(key)
            should_update = existing_field_meta_info is None or (
                new_field_meta_info is not None and new_field_meta_info.line > existing_field_meta_info.line
            )
            if not should_update:
                continue
            # If field has no meta, update. (maybe prioritize non glob?)
            # If new field has a higher line number (later command), update.
            current_cfg[layer][key] = cfg[key]
            if new_field_meta_info is not None:
                meta_cfg = current_cfg[layer].get("meta")
                if meta_cfg is None:
                    meta_cfg = dict()
                meta_cfg[key] = CommandMeta(new_field_meta_info.line, new_field_meta_info.command, is_wildchar_glob)
                current_cfg[layer]["meta"] = meta_cfg
        return current_cfg

    def _get_valid_layer_cfg(self, lname, cfg):
        """
        Return the part of the layer config dict ``cfg`` that is valid for layer ``lname``.

        Algorithms with per-layer config are expected to override this; the base implementation raises.
        """
        raise NotImplementedError(f"Feature {self._name} did not fully implement layer logic")

    def check_dataset_length(
        self,
        algo_config: FeatureConfigBaseModel,
        dataset_size_key: str,
        dataset: tf.data.Dataset,
        warning_if_larger: bool = False,
    ):
        """
        Verify the dataset has enough data samples for the given algorithm to run.

        If the dataset is too small and the size is a default value (no meta info), the configured size
        is reduced to the dataset length with a warning; if it was set explicitly, ValueError is raised.

        Args:
            algo_config: subclass of FeatureConfigBaseModel for the respective algorithm
            dataset_size_key: the key that will be fetched from the algo config (usually calibset_size / dataset_size)
            dataset: tf.data.Dataset with the algorithm's dataset
            warning_if_larger: if true, a suggestion to utilize more data samples will be shown.

        Raises:
            ValueError: an explicitly configured dataset size exceeds the dataset length.
        """
        dataset_size_from_cfg = getattr(algo_config, dataset_size_key)
        # One extra sample is counted when warning_if_larger, so "larger than configured" is detectable.
        dataset_length = get_dataset_length(dataset, threshold=dataset_size_from_cfg + int(warning_if_larger))
        is_default_dataset_size = not algo_config.has_meta_info(dataset_size_key)

        if dataset_length < dataset_size_from_cfg:
            if not is_default_dataset_size:
                raise ValueError(
                    f"Required {dataset_size_key} is {dataset_size_from_cfg}, "
                    f"but dataset contained only {dataset_length} images",
                )
            self._logger.warning(
                f"{self._name}:"
                f"\tDataset didn't have enough data for {dataset_size_key} of {dataset_size_from_cfg} "
                f"\tQuantizing using calibration size of {dataset_length}",
            )
            setattr(algo_config, dataset_size_key, int(dataset_length))
        elif dataset_length > dataset_size_from_cfg and warning_if_larger:
            self._logger.warning(
                f"Dataset is larger than {dataset_size_key} in {self._name}. "
                "Increasing the algorithm dataset size might improve the results",
            )

    def check_batch_size(self, algo_config: FeatureConfigBaseModel, dataset_size_key: str, batch_size_key: str):
        """
        Common function for batch size finalization.

        Compare the batch size and dataset size. If the batch size is larger than the dataset size:
        raise an error when the batch size was set explicitly (has meta info),
        otherwise reduce the batch size to the dataset size with a warning.

        Raises:
            ValueError: an explicitly configured batch size exceeds the dataset size.
        """
        batch_size = getattr(algo_config, batch_size_key)
        dataset_length = getattr(algo_config, dataset_size_key)
        if dataset_length >= batch_size:
            return
        is_default_batch_size = not algo_config.has_meta_info(batch_size_key)
        if not is_default_batch_size:
            raise ValueError(
                f"Required {batch_size_key} for {self._name} is {batch_size}, "
                f"but dataset contained only {dataset_length} images",
            )
        self._logger.warning(
            f"{self._name}:"
            f"\tBatch size was greater than {dataset_size_key}, using {batch_size_key} {dataset_length}",
        )
        setattr(algo_config, batch_size_key, dataset_length)

    def finalize_flat_layers_fields(self, algo_config: FeatureConfigBaseModel):
        """
        Finalize all the flat lists of layers in the config (lists without per-layer config).

        Resolve glob syntax, add scope to names, etc. Values of the fields that depend on the position
        in such a list (as listed by ``algo_config.flat_layers_fields()``) are expanded to match the glob
        length. If any entry of a list is a wildcard glob, the field's meta info is marked accordingly.
        """
        flat_layers_fields = algo_config.flat_layers_fields()
        for key in flat_layers_fields:
            values = getattr(algo_config, key)
            if values is None:
                continue
            is_wildchar_glob = False
            new_layers = []
            indices_offsets = {}
            for index, alls_lname in enumerate(values):
                is_glob_lname = alls_lname.startswith("{") and alls_lname.endswith("}")
                if is_glob_lname:
                    glob_pattern = alls_lname[1:-1]
                    layers = self._model.get_layer_names_from_glob(glob_pattern)
                    # Sticky per field: a later non-wildcard glob must not hide an earlier wildcard one.
                    is_wildchar_glob = is_wildchar_glob or "*" in glob_pattern
                    indices_offsets = self._expand_glob_dependant_values(
                        algo_config,
                        layers,
                        flat_layers_fields[key],
                        indices_offsets,
                        index,
                    )
                else:
                    layers = [self._model.get_layer_name_with_scope(alls_lname)]
                new_layers.extend(layers)
            setattr(algo_config, key, new_layers)
            if algo_config.has_meta_info(key):
                meta_info = algo_config.meta[key]
                algo_config.meta[key] = CommandMeta(meta_info.line, meta_info.command, is_wildchar_glob)

    def _expand_glob_dependant_values(
        self,
        algo_config: FeatureConfigBaseModel,
        glob_layers: List[str],
        dependants_keys: List[str],
        indices_offsets: Dict[str, int],
        command_index: int,
    ):
        """
        Update the values of the keys that depend on the glob layers.

        e.g. if the glob was expanded to 3 values, the dependent value has to be duplicated
        from 1 to 3 entries at the equivalent index. ``indices_offsets`` tracks, per dependent key,
        how far previous expansions shifted the indices; it is updated and returned.
        """
        for dep_key in dependants_keys:
            indices_offsets.setdefault(dep_key, 0)
            dep_values = getattr(algo_config, dep_key)
            if dep_values is None:
                continue
            position = command_index + indices_offsets[dep_key]
            repeated_value = dep_values.pop(position)
            dep_values[position:position] = [repeated_value] * len(glob_layers)
            indices_offsets[dep_key] += len(glob_layers) - 1
            setattr(algo_config, dep_key, dep_values)
        return indices_offsets

    def _filter_empty_layer_cfg(self, cfg):
        """Hook to drop "empty" entries from a layer config dict; the default returns ``cfg`` unchanged."""
        return cfg

    @staticmethod
    def get_block_and_layer_names(layer_name: str):
        """
        Split a leading ``block<N>__`` prefix from ``layer_name``.

        The block name is a part of the layer name, separated by "__", e.g. "block3__conv4"
        -> ("block3__", "conv4"). Names without such a prefix yield ("", layer_name).
        """
        search_block = re.search(r"^block(\d+)__", layer_name)
        if search_block is None:
            return "", layer_name
        # Strip only the leading prefix; a later occurrence of the same text is part of the layer name.
        return search_block.group(0), layer_name[search_block.end() :]

    def get_modifications_meta_data(self):
        """Return the modifications tracker after setting its stage to ``TrackerStage.QUANTIZE``."""
        self._modifications_meta_data.set_stage(TrackerStage.QUANTIZE)
        return self._modifications_meta_data
