#!/usr/bin/env python

"""Small tool to handle necessary operations when clipping activations prior to model calibration."""

import numpy as np

from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerActivationClippingConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import ActivationClippingConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import ActivationClippingMode
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.model_params.model_params import get_param_key


class ActivationClipping:
    """Prepare activation-clipping parameters before model calibration.

    On construction, the per-layer clipping configuration from ``clip_cfg`` is
    written into ``params``. Layers configured in percentile mode are tracked in
    ``percentile_clipping_layers``. Once calibration statistics and activation
    histograms are available, concrete clipping values are derived for them from
    the requested percentiles (see ``add_histograms_params`` and
    ``add_clipping_params``).
    """

    # Number of histogram bins requested per layer. The same count is used when
    # mapping a percentile back to an activation value, so the two must agree.
    _HIST_NBINS = 1000

    def __init__(self, params, hn, clip_cfg: ActivationClippingConfig, logger=None):
        """
        Args:
            params: Model parameters container. It is read/written with ``[]`` and
                ``update()``, and exposes a ``params`` attribute.
            hn: Hailo model graph. It must provide ``name`` and ``stable_toposort()``,
                whose items have a ``name``.
            clip_cfg: Activation clipping configuration. ``clip_cfg.layers`` maps layer
                names to per-layer clipping configs.
            logger: Optional logger. ``default_logger()`` is used when omitted.
        """
        self._logger = logger or default_logger()
        self._params = params
        self._model_name = hn.name
        self.percentile_clipping_layers = []
        self._hailo_model = hn
        self.act_clip_params = {}
        self._clip_cfg = clip_cfg
        self._update_params()

    def _update_params(self):
        """Write each non-disabled layer's configured clipping values into the params.

        Manual-mode values are stored under ``activation_clipping_values``.
        Percentile-mode values are stored under
        ``activation_clipping_percentile_values``, and the layer is remembered
        in ``percentile_clipping_layers``. Layers absent from ``clip_cfg.layers``
        use ``LayerActivationClippingConfig.get_default()``.
        """
        param_name = {
            ActivationClippingMode.manual: "activation_clipping_values",
            ActivationClippingMode.percentile: "activation_clipping_percentile_values",
        }

        for hn_layer in self._hailo_model.stable_toposort():
            clipping_config = self._clip_cfg.layers.get(hn_layer.name, LayerActivationClippingConfig.get_default())
            if clipping_config.mode == ActivationClippingMode.disabled:
                continue
            # NOTE(review): any mode other than manual/percentile/disabled raises KeyError here.
            key = get_param_key(hn_layer.name, param_name[clipping_config.mode])

            self.act_clip_params[key] = clipping_config.clipping_values
            if clipping_config.mode == ActivationClippingMode.percentile:
                # Percentiles only become concrete clipping values after histograms are collected.
                self.percentile_clipping_layers.append(hn_layer.name)
        self._params.update(self.act_clip_params)

    @property
    def has_clipping(self):
        """True if at least one layer has (manual or percentile) activation clipping."""
        return bool(self.act_clip_params)

    @property
    def has_percentile_clipping(self):
        """True if at least one layer uses percentile-based activation clipping."""
        return bool(self.percentile_clipping_layers)

    def add_histograms_params(self, stats):
        """Add histogram range and bin-count params for every percentile-clipped layer.

        Args:
            stats: Mapping containing ``stats_min_out`` / ``stats_max_out`` entries
                (keyed via ``get_param_key``) for each percentile-clipped layer.

        Returns:
            The updated params object.
        """
        for layer in self.percentile_clipping_layers:
            min_stat = stats[get_param_key(layer, "stats_min_out")]
            max_stat = stats[get_param_key(layer, "stats_max_out")]
            self._params[get_param_key(layer, "activation_clipping_hist_range")] = [min_stat, max_stat]
            self._params[get_param_key(layer, "activation_clipping_hist_nbins")] = self._HIST_NBINS
        # NOTE(review): re-applies the params object's own ``params`` mapping; the
        # reason is not visible in this file, so it is kept unchanged.
        self._params.update(self._params.params)
        return self._params

    def remove_unsupported_layers(self, supported_layers):
        """Drop percentile-clipped layers that are not in ``supported_layers``.

        The remaining layers keep their original (first-occurrence) order, and
        duplicates are removed. Each dropped layer is logged at debug level.

        Args:
            supported_layers: Container of layer names supporting percentile clipping.
        """
        kept_layers = []
        for layer in self.percentile_clipping_layers:
            if layer in supported_layers:
                kept_layers.append(layer)
            else:
                self._logger.debug(f"Skipping activation clipping for {layer} (isn't supported)")
        # dict.fromkeys de-duplicates while preserving first-occurrence order.
        self.percentile_clipping_layers = list(dict.fromkeys(kept_layers))

    @staticmethod
    def _np_like_percentile_from_hist(cumsum_hist, hist_bins, hist_values, percentile):
        """Approximate ``np.percentile`` (linear interpolation) from a histogram.

        Every sample in bin ``i`` is treated as having the value ``hist_bins[i]``.
        The function locates the fractional sorted-sample index
        ``percentile / 100 * (n - 1)``. When that index falls between two samples
        in different bins, the result is linearly interpolated towards the next
        non-empty bin.

        Args:
            cumsum_hist: Cumulative sum of ``hist_values``.
            hist_bins: Value associated with each bin.
            hist_values: Sample count per bin.
            percentile: Percentile, presumably in [0, 100] as for ``np.percentile``.

        Returns:
            The estimated percentile value.
        """
        percentile_index_in_values_list = percentile / 100 * (np.sum(hist_values) - 1)
        # cumsum_hist[i] - 1 is the 0-based index of the last sample in bin i, so this
        # is the first bin that contains the sample at floor(percentile index).
        bin_idx = np.where(cumsum_hist - 1 >= np.floor(percentile_index_in_values_list))[0][0]
        if cumsum_hist[bin_idx] - 1 > percentile_index_in_values_list:
            # The ceiling sample is in the same bin as well, so there is nothing to interpolate.
            return hist_bins[bin_idx]
        next_non_empty_bin_idx = bin_idx + 1
        while next_non_empty_bin_idx < len(hist_bins) and hist_values[next_non_empty_bin_idx] == 0:
            next_non_empty_bin_idx += 1
        if next_non_empty_bin_idx == len(hist_bins):
            return hist_bins[bin_idx]
        return hist_bins[bin_idx] + (percentile_index_in_values_list % 1) * (
            hist_bins[next_non_empty_bin_idx] - hist_bins[bin_idx]
        )

    def _calculate_clipping_vals_from_activation_histogram(self, activations_histograms_stats, layer):
        """Derive ``(low, high)`` clipping values for ``layer`` from its activation histogram.

        Bin values are ``_HIST_NBINS`` evenly spaced float32 points spanning the
        layer's stored ``activation_clipping_hist_range``. The low/high percentiles
        are the first two entries of ``activation_clipping_percentile_values``.

        Returns:
            Tuple of two ``np.float32`` values.
        """
        hist_values = activations_histograms_stats[f"{layer}/activation_hist:0"]
        hist_range = self._params[get_param_key(layer, "activation_clipping_hist_range")]
        hist_bins = np.linspace(hist_range[0], hist_range[1], self._HIST_NBINS, dtype=np.float32)
        percentiles = self._params[get_param_key(layer, "activation_clipping_percentile_values")]
        cumsum_hist = np.cumsum(hist_values)

        low_value = self._np_like_percentile_from_hist(cumsum_hist, hist_bins, hist_values, percentiles[0])
        high_value = self._np_like_percentile_from_hist(cumsum_hist, hist_bins, hist_values, percentiles[1])

        return np.float32(low_value), np.float32(high_value)

    def add_clipping_params(self, histograms):
        """Store percentile-derived clipping values for every percentile-clipped layer.

        Args:
            histograms: Mapping keyed ``"{layer}/activation_hist:0"`` to that layer's
                per-bin counts (presumably produced with ``_HIST_NBINS`` bins).

        Returns:
            The updated params object.
        """
        for layer in self.percentile_clipping_layers:
            clip_vals = self._calculate_clipping_vals_from_activation_histogram(histograms, layer)
            # Use the injected logger for consistency with the rest of the class.
            self._logger.debug(f"For layer {layer} the clipping are {clip_vals}")
            self._params[get_param_key(layer, "activation_clipping_values")] = clip_vals
        # NOTE(review): re-applies the params object's own ``params`` mapping; the
        # reason is not visible in this file, so it is kept unchanged.
        self._params.update(self._params.params)
        return self._params

    @property
    def params(self):
        """The (mutated in place) params object passed to the constructor."""
        return self._params
