import numpy as np

from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerWeightsClippingConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import WeightsClippingConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import WeightsClippingMode
from hailo_sdk_client.quantization.tools.optimize_kernel_ranges import mmse
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.model_params.model_params import get_param_key


class WeightsClipping:
    """Resolve per-layer weight clipping ranges from a weights-clipping config.

    On construction, every layer of the HN model (visited in ``stable_toposort`` order)
    whose effective clipping mode is not ``disabled`` gets an entry in ``layer_to_clip``.
    The entry maps the layer name to ``(param_key, values)``:

    - ``param_key`` is the layer's ``"weights_clipping_values"`` param key.
    - ``values`` is ``[[low] * output_features, [high] * output_features]``.
    """

    def __init__(self, params, hn, clip_cfg: WeightsClippingConfig, logger=None):
        """
        Args:
            params: Model params indexed by param key; the layer ``"kernel"`` entries are
                read for the ``percentile`` and ``mmse`` modes.
            hn: HN model whose layers are scanned (its ``name`` and ``stable_toposort()``
                are used).
            clip_cfg: Weights clipping configuration. Per-layer entries are looked up in
                ``clip_cfg.layers`` by layer name, falling back to
                ``LayerWeightsClippingConfig.get_default()``.
            logger: Optional logger; ``default_logger()`` is used when omitted.
        """
        self._logger = logger or default_logger()
        self._params = params
        self._model_name = hn.name
        # layer name -> (weights_clipping_values param key, [low values, high values])
        self.layer_to_clip = {}
        self._hailo_model = hn
        self._clip_cfg = clip_cfg
        self._extract_values()

    @property
    def has_clipping(self):
        """Whether at least one layer has weight clipping values."""
        return bool(self.layer_to_clip)

    def _extract_values(self):
        """Populate ``layer_to_clip`` for every layer whose effective mode is not disabled."""
        for hn_layer in self._hailo_model.stable_toposort():
            mode = self._get_real_mode(hn_layer)
            if mode == WeightsClippingMode.disabled:
                continue
            key = get_param_key(hn_layer.name, "weights_clipping_values")
            # Pass the already-resolved mode so _get_values does not recompute it.
            values = self._get_values(hn_layer, mode=mode)
            self.layer_to_clip[hn_layer.name] = (key, values)

    @property
    def weights_clip_params(self):
        """Map each clipped layer's param key to its clipping values."""
        # Each entry is a (param_key, values) pair, so dict() builds the mapping directly.
        return dict(self.layer_to_clip.values())

    def remove_unsupported_layers(self, is_supported_cb):
        """Drop layers rejected by ``is_supported_cb``.

        Args:
            is_supported_cb: Callable invoked as ``(layer_name, param_key, values)``.
                Layers for which it returns a falsy value are removed from
                ``layer_to_clip`` and a debug message is logged.
        """
        unsupported_layers = []
        for layer_name, (key, values) in self.layer_to_clip.items():
            if not is_supported_cb(layer_name, key, values):
                self._logger.debug(f"Skipping weight clipping for {layer_name} (isn't supported)")
                unsupported_layers.append(layer_name)

        # Delete after iterating: mutating a dict while iterating over it raises RuntimeError.
        for layer_name in unsupported_layers:
            del self.layer_to_clip[layer_name]

    def _get_layer_cfg(self, hn_layer):
        """Return the layer's clipping config, or the default layer config if none is set."""
        return self._clip_cfg.layers.get(hn_layer.name, LayerWeightsClippingConfig.get_default())

    def _get_real_mode(self, hn_layer):
        """Resolve the effective clipping mode for ``hn_layer``.

        ``mmse_if4b`` resolves to ``mmse`` when the layer requires native weights and its
        precision mode uses 4-bit weights; otherwise it resolves to ``disabled``.
        All other modes are returned unchanged.
        """
        mode = self._get_layer_cfg(hn_layer).mode
        if mode != WeightsClippingMode.mmse_if4b:
            return mode
        precision_mode = hn_layer.precision_config.precision_mode
        if hn_layer.requires_native_weights and precision_mode.weight_bits() == 4:
            return WeightsClippingMode.mmse
        return WeightsClippingMode.disabled

    def _get_values(self, hn_layer, mode=None):
        """Compute the per-output-feature clipping range of ``hn_layer``.

        Args:
            hn_layer: Layer to compute values for.
            mode: Already-resolved clipping mode. When None, it is resolved via
                ``_get_real_mode``.

        Returns:
            ``[[low] * output_features, [high] * output_features]``.

        Raises:
            ValueError: If the mode is not ``manual``, ``percentile`` or ``mmse``
                (e.g. ``disabled``).
        """
        if mode is None:
            mode = self._get_real_mode(hn_layer)
        clip_values = self._get_layer_cfg(hn_layer).clipping_values
        layer_name = hn_layer.name

        if mode == WeightsClippingMode.manual:
            # Configured clipping values are used as-is.
            pass
        elif mode == WeightsClippingMode.percentile:
            # Configured values are percentiles of the kernel's value distribution.
            kernel = self._params[get_param_key(layer_name, "kernel")]
            clip_values = np.percentile(kernel, clip_values)
        elif mode == WeightsClippingMode.mmse:
            kernel = self._params[get_param_key(layer_name, "kernel")]
            precision_mode = hn_layer.precision_config.precision_mode
            # Bit-width comes from the precision mode, as in _get_real_mode.
            # The class defines no WEIGHTS_BITS table, so indexing one raised AttributeError.
            clip_value = mmse(kernel, bits=precision_mode.weight_bits())
            clip_values = [-clip_value, clip_value]
        else:
            raise ValueError(f"operation value is not support {mode}")

        output_features = hn_layer.output_features
        return [[clip_values[0]] * output_features, [clip_values[1]] * output_features]
