import logging
from typing import Dict, Tuple

import networkx as nx
import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_postprocess import HailoPostprocess
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerActivationClippingConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import ActivationClippingMode
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import ActivationClippingError
from hailo_model_optimization.algorithms.equiv_matching.matching_algo import MatchingAlgo
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm
from hailo_model_optimization.algorithms.stats_collection.stats_collection import StatsCollector


class ClipActivationStats(OptimizationAlgorithm):
    """
    This class is responsible for clipping the activation statistic of layers in the model

    There are 2 types of clipping:
    - Manual: clip the range to a given value
    - Percentile: Clip the range based on approximated histogram

    The clipping can be applied in 2 modes:
    - Applied directly to the existing stats
    - Recollect the stats of the model with the new clipped range (with internal alls flag)
    """

    clip_by_layer: Dict[str, Tuple[float, float]]

    def __init__(self, model, model_config, logger_level, dataset, logger=None):
        """
        Build the clipping algorithm around a model and its optimization config

        Args:
            model: model whose layers are looked up by name during clipping
            model_config: configuration object; its ``activation_clipping`` section drives the algorithm
            logger_level: logging level forwarded to the base algorithm
            dataset: data handed to the stats collector when the stats are recollected
            logger: optional logger forwarded to the base algorithm
        """
        super().__init__(model, model_config, name="Clip Statistics", logger_level=logger_level, logger=logger)
        # One `recollect_stats` flag is appended for every layer that has clipping enabled
        self._should_recollect = []
        self._dataset = dataset
        # The matching algorithm is always created with DEBUG verbosity
        self._matching_algo = MatchingAlgo(model, model_config, logging.DEBUG)

    def _setup(self):
        """
        Prepare the algorithm for a run

        Runs the base setup and the matching algorithm setup, then resets every
        per-run accumulator so that a repeated run starts from a clean state.

        Returns:
            whatever the base class ``_setup`` returned
        """
        retval = super()._setup()
        self._matching_algo._setup()
        self.clip_by_layer = dict()
        self.clipped_source_layers = dict()
        # Reset together with the other accumulators; otherwise flags gathered by a
        # previous run would leak into the recollect consistency check of `_run_int`
        self._should_recollect = []
        return retval

    def should_skip_algo(self):
        return False

    def get_algo_config(self):
        return self._model_config.activation_clipping

    def log_config(self):
        pass

    @property
    def equiv_match(self):
        return self._matching_algo.equiv_match

    def _run_int(self):
        for matching_component_group in self.equiv_match.get_groups_components():
            self._find_range_in_components_group(matching_component_group)

        if len(self.clipped_source_layers) == 0:
            return

        if all(self._should_recollect):
            self._recollect_stats()
        elif not any(self._should_recollect):
            self._apply_new_range_to_stats()
        else:
            raise ValueError("Activation clipping commands had inconsistent recollect_stats value")

    def _apply_new_range_to_stats(self):
        """Apply the desired clipped stats on the existing model stats"""
        for source_layer, new_range in self.clipped_source_layers.items():
            if isinstance(self._model.layers[source_layer], HailoPostprocess):
                continue
            self._model.layers[source_layer].keep_original_output_stats()
            out_stats = self._model.layers[source_layer].get_output_stats()[0]
            new_min, new_max = new_range
            out_stats.min[...] = np.maximum(new_min, out_stats.min)
            out_stats.max[...] = np.minimum(new_max, out_stats.max)

    def _find_range_in_components_group(self, component_group):
        """
        Accumulate the clipped range of every source layer of a component group

        Source layers with clipping enabled constrain only themselves. Every other
        node of the group with clipping enabled constrains the source layers that
        are connected to it. Each layer handled this way also records its
        ``recollect_stats`` flag.

        Raises:
            ActivationClippingError: clipping is enabled on a non-source layer without an activation op
        """
        group_sources = self.equiv_match.source_layers_group(component_group)
        ordered_nodes = self.equiv_match.get_sorted_u_nodes_in_componenets_group(component_group)
        algo_cfg = self.get_algo_config()
        fallback_cfg = LayerActivationClippingConfig.get_default()

        def clipping_cfg_for(layer_name):
            # Layers missing from the config get registered with the default clipping config
            cfg_key = self.get_layer_name_in_config(self._model.layers[layer_name])
            algo_cfg.layers.setdefault(cfg_key, fallback_cfg)
            return algo_cfg.layers[cfg_key]

        def clipping_enabled(clip_cfg):
            return clip_cfg is not None and clip_cfg.mode != ActivationClippingMode.disabled

        # Source layers: each one is clipped by its own configuration
        for lname in group_sources:
            clip_cfg = clipping_cfg_for(lname)
            if not clipping_enabled(clip_cfg):
                continue
            self._find_range(lname, [lname], clip_cfg)
            self._should_recollect.append(clip_cfg.recollect_stats)

        # Remaining ("transparent") nodes: their configuration clips the connected source layers
        for lname in ordered_nodes:
            clip_cfg = clipping_cfg_for(lname)
            if lname in group_sources or not clipping_enabled(clip_cfg):
                continue
            if self._model.layers[lname].activation_atomic_op is None:
                fix_cmd = f"pre_quantization_optimization(activation_clipping, layers=[{lname}], mode=disabled)"
                raise ActivationClippingError(
                    f"Can't apply activation clipping to layer without activation - {lname}. "
                    f"Please use the following command:\n\t{fix_cmd}",
                )
            self._find_range(lname, group_sources, clip_cfg)
            self._should_recollect.append(clip_cfg.recollect_stats)

    def _find_range(self, target_layer, source_layers, layer_clip_cfg):
        """
        Update the clipped range of source layers based on single layer in the component

        `_find_range_in_components_group` updated the entire component,
        this function updates a single layer in the component
        """
        new_min, new_max = self.resolve_range_from_cfg(self._model.layers[target_layer], layer_clip_cfg)
        self.clip_by_layer[target_layer] = (np.min(new_min), np.max(new_max))
        source_layers = self.find_connected_source_layers(target_layer, source_layers)
        for source_layer in source_layers:
            (current_min, current_max) = self.clipped_source_layers.get(source_layer, (-np.inf, np.inf))
            # out_stats = self._model.layers[source_layer].get_output_stats()[0]
            current_min = np.maximum(new_min, current_min, dtype=np.float32)
            current_max = np.minimum(new_max, current_max, dtype=np.float32)
            self.clipped_source_layers[source_layer] = (current_min, current_max)

    def resolve_range_from_cfg(self, layer: BaseHailoLayer, layer_clip_cfg: LayerActivationClippingConfig):
        """
        Get desired range from config - e.g. resolve the percentile statistics (passthru if manual)

        The returned range never extends beyond the collected stats: each requested
        bound is intersected with the first output stats entry of the layer.

        Args:
            layer: layer whose first output stats entry is used
            layer_clip_cfg: clipping configuration of the layer (``mode`` and ``clipping_values``)

        Returns:
            (new_min, new_max); for HailoPostprocess layers the unbounded (-inf, inf) is returned

        Raises:
            ValueError: when the configured mode is neither percentile nor manual
        """
        if isinstance(layer, HailoPostprocess):
            return -np.inf, np.inf
        stats = layer.get_output_stats()[0]
        # Overall extremes of the collected stats, shared by both modes
        stats_min, stats_max = np.min(stats.min), np.max(stats.max)
        if layer_clip_cfg.mode == ActivationClippingMode.percentile:
            histogram = stats.dynamic_histogram
            # One bin value per histogram entry, evenly spread over the collected range
            hist_bins = np.linspace(stats_min, stats_max, len(histogram), dtype=np.float32)
            # The configured values are percentages; convert them to fractions
            percentile_values = np.array(layer_clip_cfg.clipping_values) / 100
            min_value = self._percentile_from_hist(histogram, hist_bins, percentile_values[0])
            max_value = self._percentile_from_hist(histogram, hist_bins, percentile_values[1])
            clip_values = min_value, max_value
        elif layer_clip_cfg.mode == ActivationClippingMode.manual:
            clip_values = layer_clip_cfg.clipping_values
            if clip_values[0] < stats_min or clip_values[1] > stats_max:
                self._logger.warning(
                    f"Layer {layer.name} configured with manual range of [{clip_values[0]}, {clip_values[1]}]. "
                    f"The actual stats were [{stats_min}, {stats_max}] so the range won't be clipped. "
                    f"Please use force_range command instead if you would like to expand the range."
                )
        else:
            raise ValueError(f"Unexpected activation clipping mode {layer_clip_cfg.mode}")
        # Clipping may only narrow the collected range, never widen it
        new_max = np.minimum(stats.max, clip_values[1], dtype=stats.max.dtype)
        new_min = np.maximum(stats.min, clip_values[0], dtype=stats.min.dtype)
        return new_min, new_max

    def _recollect_stats(self):
        """
        Rerun the stats collector after the clipped range has been applied to the layers
        EXPERIMENTAL
        This feature is not fully ready, and it has some assumptions:
        - the activations stats are collected in native mode
        - the source layers has activation op (which is a general assumption in the activation clipping)
        - updates _clip_range of the activation_op directly, which is used only in call_native (atm)
        """
        for lname, clip_range in self.clipped_source_layers.items():
            layer = self._model.layers[lname]
            act_op = layer.activation_atomic_op
            if act_op is None:
                raise ValueError(f"Layer {lname} don't have activation op, can't apply activation clipping")
            act_op._clip_range = clip_range
        stats_collector = StatsCollector(
            self._model,
            self._model_config,
            logging.INFO,
            self._dataset,
            logger=self._logger,
        )
        stats_collector.run()

        for lname in self.clipped_source_layers:
            layer = self._model.layers[lname]
            act_op = layer.activation_atomic_op
            if act_op is None:
                raise ValueError(f"Layer {lname} don't have activation op, can't apply activation clipping")
            act_op._clip_range = None

    @staticmethod
    def _percentile_from_hist(hist_values, hist_bins, percentile):
        hist_cumsum = np.cumsum(hist_values)
        percentile_index_in_values_list = percentile * (hist_cumsum[-1] - 1)
        bin_idx = np.where(hist_cumsum - 1 >= np.floor(percentile_index_in_values_list))[0][0]
        if hist_cumsum[bin_idx] - 1 > percentile_index_in_values_list:
            return hist_bins[bin_idx]
        next_non_empty_bin_idx = bin_idx + 1
        while next_non_empty_bin_idx < len(hist_bins) and hist_values[next_non_empty_bin_idx] == 0:
            next_non_empty_bin_idx += 1
        if next_non_empty_bin_idx == len(hist_bins):
            return hist_bins[bin_idx]
        return hist_bins[bin_idx] + (percentile_index_in_values_list % 1) * (
            hist_bins[next_non_empty_bin_idx] - hist_bins[bin_idx]
        )

    def find_connected_source_layers(self, target_layer, source_layers):
        """
        Finds the all the source in a component connected to a target layer

        Checks which source layers has path to the target layer
        """
        connected_sources = []
        for source_layer in source_layers:
            if nx.has_path(self._model.flow, source_layer, target_layer):
                connected_sources.append(source_layer)
        return connected_sources

    def export_statistics(self):
        stats = {}
        for layer in self.clip_by_layer:
            stats[f"{layer}/clip_values"] = self.clip_by_layer[layer]

        return stats

    def finalize_global_cfg(self, algo_config):
        pass

    def _get_valid_layer_cfg(self, lname, cfg):
        return cfg
