from copy import deepcopy
from enum import Enum

from pydantic.v1.error_wrappers import ErrorWrapper, ValidationError

from hailo_model_optimization.acceleras.model_optimization_config.mo_config import (
    ModelOptimizationConfig,
    update_nested,
)
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerTranslationConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationClippingMode,
    BiasMode,
    PrecisionMode,
    SEOptimizationMethod,
    TiledSqueezeAndExciteMode,
    WeightsClippingMode,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import ConfigurationError
from hailo_sdk_common.hailo_nn.hn_definitions import HWLayerType, LayerType, ResizeMethod
from hailo_sdk_common.logger.logger import default_logger


class DoubleScaleBiasModeSupport(Enum):
    """
    How a layer supports double-scale (16 bit) bias.

    Returned by ``ModelConfigurationValidator._check_16bit_bias_support`` and used to validate
    ``precision_mode`` and ``bias_mode`` commands.
    """

    # The layer has to use double-scale bias; single-scale decomposition cannot be requested for it.
    ALWAYS = "always"
    # The layer may use either double- or single-scale bias.
    ALWAYS_ALLOWED = "always_allowed"
    # The layer does not support 16 bit bias at all.
    NOT_ALLOWED = "not_allowed"


def apply_quantization_config_to_hn(hn, config: ModelOptimizationConfig):
    """
    Copy per-layer ``quantization_groups`` from the optimization config onto the HN layers.

    Historically this applied the entire config to the HN. It now only propagates the quantization
    groups, which are needed for params sorting. It is expected to be removed completely in the
    near future.

    :param hn: HN model; every layer named in ``config.precision_config.layers`` is looked up on it.
    :param config: the model optimization config holding the per-layer precision settings.
    """
    per_layer_precision = config.precision_config.layers
    for lname, lprecision in per_layer_precision.items():
        # Lookup happens for every configured layer, even when there is nothing to copy.
        hn_layer = hn.get_layer_by_name(lname)
        groups = lprecision.quantization_groups
        if groups is None:
            continue
        hn_layer.precision_config.quantization_groups = groups


def verify_16bit_ratio_with_bias_commands(config):
    """
    Reject per-layer bias mode commands when the whole network is forced to 16 bit.

    When ``compression_params.auto_16bit_weights_ratio`` equals 1, the entire network is 16 bit,
    and a user-specified ``bias_mode`` on any layer is not supported.

    :param config: config exposing ``compression_params.auto_16bit_weights_ratio`` and
        ``precision_config.layers`` (a mapping of layer name to layer precision config).
    :raises ConfigurationError: for the first layer (in iteration order) whose ``bias_mode`` is set.
    """
    whole_network_is_16bit = config.compression_params.auto_16bit_weights_ratio == 1
    if not whole_network_is_16bit:
        return

    for lname, lprecision in config.precision_config.layers.items():
        if lprecision.bias_mode is None:
            continue
        raise ConfigurationError(
            f"auto_16bit_weights_ratio is set to 1. Layer bias mode configuration is not supported. Found config for {lname}. Please remove the bias mode command from the alls script",
            loc=["layers_config"],
        )


def verify_commands(hn_model, commands, custom_default=None, allocation_mode=False, pre_quantization_mode=False):
    """
    Build a ModelOptimizationConfig from user commands and validate it against ``hn_model``.

    :param hn_model: HN model used to resolve layer names and layer types.
    :param commands: nested dict of user commands; ``None`` is treated as "no commands".
    :param custom_default: optional nested dict of defaults. A deep copy of it is merged with
        ``commands`` via ``update_nested(defaults_copy, commands)``.
    :param allocation_mode: forwarded to ``ModelConfigurationValidator``.
    :param pre_quantization_mode: forwarded to ``ModelConfigurationValidator``.
    :return: the validated config object.
    :raises ValidationError: if validation fails. Before re-raising, each error message is extended
        with the originating command/line information when it is available.
    """
    # Normalize up front so the error-meta lookup below never walks into None (which would raise
    # a TypeError and hide the real ValidationError).
    if commands is None:
        commands = {}
    if custom_default is not None:
        # Copy so the caller's defaults object is not modified by the merge.
        commands = update_nested(deepcopy(custom_default), commands)

    try:
        verifier = ModelConfigurationValidator(hn_model, commands, allocation_mode, pre_quantization_mode)
        verifier.validate()
    except ValidationError as e:
        # The error dicts are updated in place; pydantic v1 caches the list returned by errors().
        for error in e.errors():
            update_error_meta(error, commands)
        raise
    return verifier.config


def update_error_meta(error, commands):
    """
    Append the originating command and line information to a validation error message, in place.

    ``error`` is a pydantic-style error dict with ``"loc"`` and ``"msg"`` keys. The leading part of
    ``loc`` is used to walk into ``commands``: three keys when ``loc`` starts with ``"layers"``,
    otherwise one key. If the node reached has a ``"meta"`` mapping, the entries relevant to the
    failing field are appended to ``error["msg"]``. For a ``"__root__"`` error that is every entry;
    otherwise it is the entry named by the last element of ``loc``. Each meta entry is indexed as
    ``(line_number, command_name)``.

    This is best effort. It runs while a ``ValidationError`` is being re-raised, so when ``loc``
    cannot be resolved in ``commands``, or no meta entry matches, the error is left untouched.
    It does not raise, which would mask the original error.

    :param error: mutable error dict; ``error["msg"]`` may be extended.
    :param commands: nested commands the config was built from (may be ``None``).
    """
    loc = error["loc"]
    if not loc:
        return
    depth = 3 if loc[0] == "layers" else 1

    node = commands
    try:
        for key in loc[:depth]:
            node = node[key]
        meta = node.get("meta", None)
    except (KeyError, IndexError, TypeError, AttributeError):
        # loc does not map onto the user commands (e.g. the value came from a default).
        return
    if meta is None:
        return

    field = loc[-1]
    if field == "__root__":
        entries = [(name, meta[name]) for name in meta]
    elif field in meta:
        entries = [(field, meta[field])]
    else:
        entries = []
    if not entries:
        # Nothing to report; avoid appending an empty "; ; commands were " suffix.
        return

    lines = [f"{name} from line {entry[0]}" for name, entry in entries]
    # De-duplicate command names while keeping a deterministic, first-seen order.
    source_commands = list(dict.fromkeys(entry[1] for _, entry in entries))
    be_verb = " was" if len(source_commands) == 1 else "s were"
    error["msg"] += f"; {', '.join(lines)}; command{be_verb} {', '.join(source_commands)}"


class ModelConfigurationValidator:
    """
    Validate a ModelOptimizationConfig (built from user commands) against an HN model.

    Per-layer validators follow one pattern:
    - a valid setting is kept;
    - an invalid setting that was derived from a glob pattern (see ``is_glob``) is reset in place;
    - an invalid explicit setting raises ``ConfigurationError``.

    ``validate`` gathers those errors into a single pydantic ``ValidationError``. Some validators also
    normalize the config in place (e.g. ``validate_finetune_layers`` rewrites layer names to full names).
    """

    # Layer types that may use a double-scale (16 bit) bias. Resize is further restricted by
    # ``_check_resize_16bit_support``.
    SUPPORT_16BIT_BIAS = [
        LayerType.conv,
        LayerType.dense,
        LayerType.dw,
        LayerType.batch_norm,
        LayerType.matmul,
        LayerType.deconv,
        LayerType.activation,
        LayerType.normalization,
        LayerType.ew_mult,
        LayerType.feature_multiplier,
        LayerType.resize,
        LayerType.const_input,
        LayerType.reduce_sum,
    ]
    # Layer types that always use a 16 bit bias, so it cannot be disabled for them.
    ALWAYS_USE_16BIT_BIAS = [LayerType.ew_add, LayerType.ew_sub, LayerType.bbox_decoder, LayerType.fused_bbox_decoder]

    # Layer types with statistics collection. The force_range_in/out checks and the preact-stats
    # check rely on this set.
    _STATS_COLLECTION_LAYER_TYPES = frozenset(
        {
            LayerType.conv,
            LayerType.dw,
            LayerType.avgpool,
            LayerType.deconv,
            LayerType.dense,
            LayerType.batch_norm,
            LayerType.normalization,
            LayerType.bbox_decoder,
            LayerType.fused_bbox_decoder,
            LayerType.ew_add,
            LayerType.ew_sub,
            LayerType.ew_mult,
            LayerType.feature_multiplier,
            LayerType.activation,
            LayerType.reduce_sum,
            LayerType.matmul,
        },
    )
    # Stats-collection types plus resize (used by ``_layer_has_preact_stats``).
    _PREACT_STATS_LAYER_TYPES_WITH_RESIZE = _STATS_COLLECTION_LAYER_TYPES | {LayerType.resize}
    # Stats-collection types except the bbox decoders (used by ``layer_has_force_shift``).
    _FORCE_SHIFT_LAYER_TYPES = _STATS_COLLECTION_LAYER_TYPES - {LayerType.bbox_decoder, LayerType.fused_bbox_decoder}

    def __init__(self, hn_model, commands, allocation_mode, pre_quantization_mode):
        """
        :param hn_model: HN model used to resolve layer names and types.
        :param commands: nested dict of commands; ``None`` is treated as empty.
        :param allocation_mode: if True, only check for conflicts with precision settings already on the HN.
        :param pre_quantization_mode: if True, also validate the TSE (se_optimization) layers.
        """
        self._hn_model = hn_model
        if commands is None:
            commands = {}
        self._config = ModelOptimizationConfig(**commands)
        self._allocation_mode = allocation_mode
        self._pre_quantization_mode = pre_quantization_mode

    @property
    def validators(self):
        """Validator callables to run, selected by the pre-quantization and allocation modes."""
        validators = []
        if self._pre_quantization_mode:
            validators.append(self.validate_tse_layers)
        if not self._allocation_mode:
            validators.extend(
                [
                    self.validate_finetune_layers,
                    self.validate_activation_clipping,
                    self.validate_weights_clipping,
                    self.validate_precision_config,
                    self.validate_translation_config,
                ],
            )
        else:
            validators.extend(
                [
                    self.validate_no_precision_config_conflicts,
                ],
            )
        return validators

    def get_global_config_errors(self):
        """Run every validator and wrap each ``ConfigurationError`` into a pydantic ``ErrorWrapper``."""
        errors = []
        for validator in self.validators:
            try:
                validator()
            except ConfigurationError as exc:
                errors.append(ErrorWrapper(exc, (*exc.loc,)))
        return errors

    def validate_activation_clipping(self):
        """Check that every layer with an activation clipping config actually has an activation."""
        for lname, layer_cfg in self._config.activation_clipping.layers.items():
            hn_layer = self._hn_model.get_layer_by_name(lname)
            self._validate_layer_activation_clipping(hn_layer, layer_cfg)

    def _validate_layer_activation_clipping(self, hn_layer, act_clipping_config):
        """Validate one layer's activation clipping config (no-op when missing or disabled)."""
        if act_clipping_config is None:
            return
        if act_clipping_config.mode == ActivationClippingMode.disabled:
            return
        # Raises for explicit configs on layers without activation; glob-derived ones are disabled in place.
        self._has_activation_with_verify(hn_layer, act_clipping_config)

    def _has_activation_with_verify(self, hn_layer, act_clipping_config):
        """
        Return True if ``hn_layer`` has an ``activation`` attribute.

        Otherwise, if the clipping mode came from a glob, disable clipping on the config and return False.
        An explicit clipping mode on a layer without activation raises ``ConfigurationError``.
        """
        has_activation = hasattr(hn_layer, "activation")
        if has_activation:
            return True
        errmsg = f"Layer {hn_layer.full_name} has no activation"
        loc = ["activation_clipping", "mode"]
        if self.is_glob(act_clipping_config, "mode"):
            act_clipping_config.mode = ActivationClippingMode.disabled
            act_clipping_config.clipping_values = None
            return False
        raise ConfigurationError(errmsg, loc=loc)

    def validate_weights_clipping(self):
        """Check every layer-level weights clipping config against its HN layer."""
        for lname, layer_cfg in self._config.weights_clipping.layers.items():
            hn_layer = self._hn_model.get_layer_by_name(lname)
            self._validate_layer_weights_clipping(hn_layer, layer_cfg)

    def _validate_layer_weights_clipping(self, hn_layer, weights_clipping_config):
        """
        Validate one layer's weights clipping config.

        Accepted when missing, when the mode is ``disabled`` or ``mmse_if4b``, or when
        ``hn_layer.requires_native_weights`` is truthy. Otherwise the config is disabled in place if it
        was glob-derived, or ``ConfigurationError`` is raised.
        """
        if weights_clipping_config is None:
            return
        if weights_clipping_config.mode in {WeightsClippingMode.disabled, WeightsClippingMode.mmse_if4b}:
            return
        if hn_layer.requires_native_weights:
            return
        errmsg = f"Layer {hn_layer.full_name} has no weights"
        loc = ["weights_clipping", "mode"]
        if self.is_glob(weights_clipping_config, "mode"):
            weights_clipping_config.mode = WeightsClippingMode.disabled
            weights_clipping_config.clipping_values = None
            return
        raise ConfigurationError(errmsg, loc=loc)

    def validate_tse_layers(self):
        """For custom TSE se_optimization, check every listed layer is a global avgpool."""
        config = self._config.se_optimization
        if (
            (config is None)
            or (config.method != SEOptimizationMethod.tse)
            or (config.mode != TiledSqueezeAndExciteMode.custom)
        ):
            return
        invalid_layers = []
        for layer in config.layers:
            hn_layer = self._hn_model.get_layer_by_name(layer)
            if not (hn_layer.op == LayerType.avgpool and hn_layer.is_global_avg_pool()):
                invalid_layers.append(layer)
        if invalid_layers:
            raise ConfigurationError(
                f"The following layers are not global avgpool: {invalid_layers}",
                loc=["se_optimization", "layers"],
            )

    def validate_finetune_layers(self):
        """Replace layer names in the finetune config with the HN layers' full names (in place)."""
        finetune_config = self._config.finetune
        if finetune_config is None:
            return
        for field in ("layers_to_freeze", "loss_layer_names", "native_layers"):
            layer_names = getattr(finetune_config, field)
            if layer_names is None:
                continue
            setattr(
                finetune_config,
                field,
                [self._hn_model.get_layer_by_name(layer).full_name for layer in layer_names],
            )

    def validate_precision_config(self):
        """Validate precision mode, bias mode and quantization groups for every configured layer."""
        for lname, layer_precision_config in self._config.precision_config.layers.items():
            self._validate_precision_mode(lname, layer_precision_config)
            self._validate_bias_mode(lname, layer_precision_config)
            self._validate_quantization_groups(lname, layer_precision_config)

    def _validate_precision_mode(self, lname, layer_precision_config: LayerPrecisionConfig):
        """Reject ``precision_mode=a8_w4`` on layers that do not support double-scale (16 bit) bias."""
        precision_mode = layer_precision_config.precision_mode
        if precision_mode is None:
            return

        layer_support_16bit = self._check_16bit_bias_support(lname)
        if precision_mode != PrecisionMode.a8_w4 or layer_support_16bit != DoubleScaleBiasModeSupport.NOT_ALLOWED:
            return

        # Only remaining case: a8_w4 requested on a layer without double-scale bias support.
        errmsg = (
            f"Layer {lname} doesn't support precision_mode={PrecisionMode.a8_w4.value} "
            f"because it does not support double scale bias"
        )
        loc = ["precision_config", "precision_mode"]
        if self.is_glob(layer_precision_config, "precision_mode"):
            layer_precision_config.precision_mode = None
        else:
            raise ConfigurationError(errmsg, loc=loc)

    def _validate_bias_mode(self, lname, layer_precision_config: LayerPrecisionConfig):
        """
        Check ``bias_mode`` against the layer's 16 bit bias support.

        Single-scale decomposition is rejected on layers that always use 16 bit bias. Any other bias
        mode is rejected on layers that do not support 16 bit bias.
        """
        bias_mode = layer_precision_config.bias_mode
        if bias_mode is None:
            return

        bias_16bit_support = self._check_16bit_bias_support(lname)
        is_bias_single_scale = bias_mode == BiasMode.single_scale_decomposition
        errmsg = None
        if is_bias_single_scale and bias_16bit_support == DoubleScaleBiasModeSupport.ALWAYS:
            errmsg = f"Layer {lname} always uses 16 bit bias therefore it cannot be disabled"
        elif not is_bias_single_scale and bias_16bit_support == DoubleScaleBiasModeSupport.NOT_ALLOWED:
            errmsg = f"Layer {lname} does not support 16 bit bias, therefore it is not used"

        if errmsg is None:
            return

        loc = ["precision_config", "bias_mode"]
        if self.is_glob(layer_precision_config, "bias_mode"):
            layer_precision_config.bias_mode = None
        else:
            raise ConfigurationError(errmsg, loc=loc)

    def _validate_quantization_groups(self, lname, layer_precision_config: LayerPrecisionConfig):
        """
        Check ``quantization_groups`` is only set (beyond 1) on layers that support it.

        A value of -1 resets the field to None (unset).
        """
        hn_layer = self._hn_model.get_layer_by_name(lname)

        if layer_precision_config.quantization_groups == -1:
            layer_precision_config.quantization_groups = None
            return
        if self._layer_support_quantization_groups(hn_layer):
            return
        if layer_precision_config.quantization_groups is None or layer_precision_config.quantization_groups == 1:
            return
        errmsg = (
            f"Can't set quantization_groups to be {layer_precision_config.quantization_groups} on layer {lname} because "
            f"it is of type {hn_layer.op.name}"
        )
        loc = ["precision_config", "quantization_groups"]
        if self.is_glob(layer_precision_config, "quantization_groups"):
            layer_precision_config.quantization_groups = None
        else:
            raise ConfigurationError(errmsg, loc=loc)

    def _check_16bit_bias_support(self, lname):
        """
        Classify the 16 bit bias support of layer ``lname``.

        ALWAYS: the layer has to use double-scale bias mode.
        ALWAYS_ALLOWED: the layer can use either double- or single-scale mode.
        NOT_ALLOWED: the layer does not support 16 bit bias at all.
        """
        hn_layer = self._hn_model.get_layer_by_name(lname)
        if hn_layer.op in self.ALWAYS_USE_16BIT_BIAS:
            return DoubleScaleBiasModeSupport.ALWAYS
        if hn_layer.op not in self.SUPPORT_16BIT_BIAS:
            return DoubleScaleBiasModeSupport.NOT_ALLOWED
        if hn_layer.op == LayerType.resize:
            return self._check_resize_16bit_support(hn_layer)
        return DoubleScaleBiasModeSupport.ALWAYS_ALLOWED

    def _check_resize_16bit_support(self, hn_layer):
        """
        Classify the 16 bit bias support of a resize layer.

        Bilinear resize supports it, unless the layer is both mapped to PPU and uses align-corners.
        """
        if hn_layer.resize_method == ResizeMethod.bilinear and not (
            HWLayerType.ppu in hn_layer.compilation_params.get("hw_layer_type_list", [])
            and hn_layer.is_bilinear_align_corners
        ):
            return DoubleScaleBiasModeSupport.ALWAYS_ALLOWED
        return DoubleScaleBiasModeSupport.NOT_ALLOWED

    def _layer_support_quantization_groups(self, hn_layer):
        """Quantization groups are supported on conv/deconv/dense layers without a fused ew_add."""
        elwa = hn_layer.ew_add_enabled
        return hn_layer.op in {LayerType.conv, LayerType.deconv, LayerType.dense} and not elwa

    def validate_translation_config(self):
        """Validate forced input/output ranges for every configured layer."""
        # NOTE(review): _validate_force_range_preact / _validate_force_shift are not called here;
        # confirm whether that is intentional.
        for lname, layer_translation_config in self._config.translation_config.layers.items():
            self._validate_force_range_in(lname, layer_translation_config)
            self._validate_force_range_out(lname, layer_translation_config)

    def _validate_force_range_in(self, lname, layer_translation_config: LayerTranslationConfig):
        """``force_range_in`` is only allowed on layers with statistics collection."""
        hn_layer = self._hn_model.get_layer_by_name(lname)
        if self._layer_has_stats_collection(hn_layer):
            return
        if layer_translation_config.force_range_in is None:
            return

        errmsg = f"Can't force input ranges on layer {hn_layer.full_name} because it is of type {hn_layer.op.name}"
        loc = ["translation_config", "force_range_in"]
        if self.is_glob(layer_translation_config, "force_range_in"):
            layer_translation_config.force_range_in = None
        else:
            raise ConfigurationError(errmsg, loc=loc)

    def _validate_force_range_out(self, lname, layer_translation_config: LayerTranslationConfig):
        """``force_range_out`` is allowed on stats-collection, input and const-input layers."""
        hn_layer = self._hn_model.get_layer_by_name(lname)
        if self._layer_has_stats_collection(hn_layer) or hn_layer.op in [LayerType.input_layer, LayerType.const_input]:
            return
        if layer_translation_config.force_range_out is None:
            return

        errmsg = f"Can't force output ranges on layer {hn_layer.full_name} because it is of type {hn_layer.op.name}"
        loc = ["translation_config", "force_range_out"]
        if self.is_glob(layer_translation_config, "force_range_out"):
            layer_translation_config.force_range_out = None
        else:
            raise ConfigurationError(errmsg, loc=loc)

    def _validate_force_range_preact(self, lname, layer_translation_config: LayerTranslationConfig):
        """``force_range_preact`` is only allowed on layers with pre-activation statistics."""
        hn_layer = self._hn_model.get_layer_by_name(lname)
        if self.layer_has_preact_stats(hn_layer):
            return
        if layer_translation_config.force_range_preact is None:
            return

        errmsg = (
            f"Can't force pre-activation ranges on layer {hn_layer.full_name} because it is of type {hn_layer.op.name}"
        )
        loc = ["translation_config", "force_range_preact"]
        if self.is_glob(layer_translation_config, "force_range_preact"):
            layer_translation_config.force_range_preact = None
        else:
            raise ConfigurationError(errmsg, loc=loc)

    @staticmethod
    def layer_has_preact_stats(hn_layer):
        """Whether ``hn_layer`` has pre-activation statistics (same set as stats collection)."""
        return hn_layer.op in ModelConfigurationValidator._STATS_COLLECTION_LAYER_TYPES

    def _layer_has_stats_collection(self, hn_layer):
        """Whether statistics are collected for ``hn_layer``."""
        return hn_layer.op in self._STATS_COLLECTION_LAYER_TYPES

    @staticmethod
    def layer_has_force_shift(hn_layer):
        """Whether ``force_shift`` may be set on ``hn_layer`` (stats-collection types without bbox decoders)."""
        return hn_layer.op in ModelConfigurationValidator._FORCE_SHIFT_LAYER_TYPES

    def _validate_force_shift(self, lname, layer_translation_config: LayerTranslationConfig):
        """``force_shift`` is only allowed on layer types listed by ``layer_has_force_shift``."""
        hn_layer = self._hn_model.get_layer_by_name(lname)
        if self.layer_has_force_shift(hn_layer):
            return
        if layer_translation_config.force_shift is None:
            return

        errmsg = f"Can't force shift on layer {hn_layer.full_name} because it is of type {hn_layer.op.name}"
        loc = ["translation_config", "force_shift"]
        if self.is_glob(layer_translation_config, "force_shift"):
            layer_translation_config.force_shift = None
        else:
            raise ConfigurationError(errmsg, loc=loc)

    def _layer_has_preact_stats(self, hn_layer):
        """Like ``layer_has_preact_stats`` but also accepts resize layers."""
        return hn_layer.op in self._PREACT_STATS_LAYER_TYPES_WITH_RESIZE

    def validate_no_precision_config_conflicts(self):
        """
        In allocation mode, reject precision settings that differ from those already on the HN layers.

        Keys whose new value is None are ignored.
        """
        for lname, layer_precision_config in self._config.precision_config.layers.items():
            conflicting_keys = {}
            hn_layer = self._hn_model.get_layer_by_name(lname)
            # NOTE(review): a stock pydantic v1 model yields (name, value) pairs when iterated; this loop
            # presumably relies on LayerPrecisionConfig iterating field names - confirm.
            for key in layer_precision_config:
                existing_value = getattr(hn_layer.precision_config, key)
                new_value = getattr(layer_precision_config, key)
                if new_value is None:
                    continue
                if new_value != existing_value:
                    conflicting_keys[key] = (new_value, existing_value)

            if conflicting_keys:
                conflicts = [
                    f"{key} - new value {new_v} conflicts with existing value {old_v}"
                    for key, (new_v, old_v) in conflicting_keys.items()
                ]
                conflicts_str = "\n".join(conflicts)
                errmsg = (
                    f"The following quantization params have different value than those previously set for "
                    f"layer '{hn_layer.full_name}':\n"
                    f"{conflicts_str}\n"
                    f"This modification is not allowed in allocation mode."
                )
                raise ConfigurationError(errmsg, loc=["precision_config"])

    def validate(self):
        """Run all validators and raise a pydantic ``ValidationError`` if any of them failed."""
        errors = self.get_global_config_errors()
        if errors:
            raise ValidationError(errors, type(self._config))

    @property
    def config(self):
        """The (possibly normalized in place) ModelOptimizationConfig."""
        return self._config

    @staticmethod
    def is_glob(config, field):
        """Whether ``field`` of ``config`` was set through a glob pattern, according to ``config.meta``."""
        return False if config.meta is None or field not in config.meta else config.meta[field].is_glob

    def glob_warning(self, layer_config, msg, loc):
        # TODO: fix glob warning and add after is_glob checks
        warning_data = {"loc": loc, "msg": msg}
        update_error_meta(warning_data, layer_config.dict())
        default_logger().debug(f"Ignoring config (derived from glob); {warning_data['msg']}")
