import copy
import fnmatch
from abc import ABC, abstractmethod
from collections import OrderedDict
from enum import Enum
from typing import Dict

from hailo_model_optimization.acceleras.model_optimization_config.mo_config import update_nested
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerPrecisionConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    LAYERS_KEY,
    PrecisionMode,
    QuantizationDeprecatedParam,
    Use16bitBiasPolicies,
    Use16bitBiasPolicyToBiasMode,
)
from hailo_sdk_client.sdk_backend.script_parser.commands import CommandsGroups, ModelScriptCommand, SupportedCommands
from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import (
    AllocatorScriptParserException,
    QuantizationScriptParserException,
)
from hailo_sdk_common.logger.logger import DeprecationVersion, default_logger


def _value_to_str(values):
    """Render a command argument value the way it would appear in a model script.

    - dict: rendered as the sorted list of its keys.
    - list/tuple holding exactly one glob pattern (a string containing ``*``): ``{pattern}``.
    - any other list/tuple: ``[a, b, ...]`` with every element rendered recursively.
    - scalar: ``str(value)``, with integral floats shown without a trailing ``.0``.
    """
    # TODO: consider replacing the raw dict normal to dict and convert the enum here.
    items = sorted(values) if isinstance(values, dict) else values

    if not isinstance(items, (list, tuple)):
        return str(_float_to_int_if_valid(items))

    if len(items) == 1 and _has_glob(items[0]):
        # A lone glob pattern is written in brace syntax rather than as a list.
        return f"{{{items[0]}}}"

    return "[{}]".format(", ".join(map(_value_to_str, items)))


def _float_to_int_if_valid(v):
    if not isinstance(v, float):
        return v
    if v == int(v):
        return int(v)
    return v


def _has_glob(v):
    if not isinstance(v, str):
        return False

    return "*" in v


def get_param_value_from_policy(value, policies):
    """Resolve a script value to a policy value.

    *value* is first looked up directly in *policies*. Failing that, boolean-like values
    (``True``/``"True"``/``"true"`` and ``False``/``"False"``/``"false"``) are mapped to
    ``policies["enabled"]`` / ``policies["disabled"]`` respectively.

    Returns:
        The matching policy value, or None when *value* is not recognized.
    """
    truthy_aliases = (True, "True", "true")
    falsy_aliases = (False, "False", "false")

    if value in policies:
        return policies[value]
    if value in truthy_aliases:
        return policies["enabled"]
    if value in falsy_aliases:
        return policies["disabled"]
    return None


class ModelOptimizationCommand(ModelScriptCommand, ABC):
    """Abstract base for model-optimization (quantization) script commands.

    Remembers the raw command text and the script line it came from. Exported configs
    can then carry a ``meta`` entry pointing back to the originating command. Also
    provides helpers for expanding glob patterns over layer names.
    """

    def __init__(self, function_name, cmd_str, loc, function_args=None):
        # Both attributes are set before delegating to the parent constructor.
        self._command_string = cmd_str
        self._command_line = loc
        super().__init__(function_name, function_args=function_args)

    @abstractmethod
    def __str__(self):
        """Render the command as script text.

        Implementations must rebuild the text from the current parameters (rather than
        returning ``self._command_string``) so that scope changes made by ``add_scope``
        are reflected.
        """

    def _build_meta(self, key, is_glob):
        """Return ``{key: (line, command_string, is_glob)}`` describing where *key* was set."""
        provenance = (self._command_line, self._command_string, is_glob)
        return {key: provenance}

    @property
    def group(self):
        return CommandsGroups.QUANTIZATION

    @abstractmethod
    def export(self):
        """Return this command's contribution to the optimization configuration."""

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        """Expand layer glob patterns; a no-op for commands that have no layers."""

    def _expand_glob_for_key(self, layers_from_script, layers_in_hn, net_scopes) -> Dict[str, bool]:
        """Expand *layers_from_script* into a ``{layer_name: is_glob_derived}`` dict.

        A dict input has already been expanded and is returned untouched. Any other
        non-list value is treated as a single pattern.
        """
        if isinstance(layers_from_script, dict):
            return layers_from_script
        patterns = layers_from_script if isinstance(layers_from_script, list) else [layers_from_script]
        return self._get_layers_glob_syntax(patterns, layers_in_hn, net_scopes)

    def _get_layers_glob_syntax(self, layers, layers_in_hn, net_scopes) -> Dict[str, bool]:
        """Match every pattern in *layers* against *layers_in_hn* with ``fnmatch``.

        When there is exactly one net scope, each pattern is scoped first. The result
        maps each matched layer to True only if every pattern that matched it contained
        ``*``; a layer that was also named by a non-glob pattern maps to False. Warnings
        are emitted for explicit names that matched nothing, and when nothing matched at all.
        """
        missing_names = []
        glob_flags = {}
        for pattern in layers:
            if len(net_scopes) == 1:
                pattern = self.add_scope_to_layer(net_scopes, pattern)
            matches = fnmatch.filter(layers_in_hn, pattern)
            if not matches and "*" not in pattern:
                # Only explicit (non-wildcard) names are reported as missing.
                missing_names.append(pattern)
            glob_flags.update(self._get_is_glob_derived(matches, pattern, glob_flags))
        self._verify_valid_layers(missing_names, glob_flags.keys(), layers)
        return glob_flags

    @staticmethod
    def _get_is_glob_derived(layers_to_add, pattern, is_glob_derived):
        """Return ``{layer: flag}`` for *layers_to_add*.

        A layer's flag is True only if it was not previously marked False in
        *is_glob_derived* AND *pattern* contains ``*``.
        """
        return {name: is_glob_derived.get(name, True) and "*" in pattern for name in layers_to_add}

    def _verify_valid_layers(self, layers_not_in_hn, unfolded_layers, patterns):
        """Log warnings for unmatched explicit layer names and for an empty expansion."""
        if layers_not_in_hn:
            default_logger().warning(
                f"line {self._command_line}: layers {layers_not_in_hn} could not be found in scope.",
            )
        if not unfolded_layers:
            default_logger().warning(
                f"line {self._command_line}: None of the layers {patterns} were found in the hn",
            )

    @abstractmethod
    def add_scope(self, scope_names, force=False):
        """Prefix this command's layer names with *scope_names* (required if the command has layers)."""

    def _update_meta(self, data, is_glob):
        """Recursively fill every nested dict that owns a ``meta`` key with provenance info."""
        for entry in data.values():
            if not isinstance(entry, dict):
                continue
            if "meta" in entry:
                self._update_meta_single_field(entry, is_glob)
            else:
                self._update_meta(entry, is_glob)

    def _update_meta_single_field(self, cfg_entry, is_glob):
        """Replace ``cfg_entry["meta"]`` with provenance for each of its other keys."""
        collected = {}
        for field_name in cfg_entry:
            if field_name == "meta":
                continue
            collected.update(self._build_meta(field_name, is_glob))
        cfg_entry["meta"] = collected


class ModelOptimizationFlavorCommand(ModelOptimizationCommand):
    """``model_optimization_flavor(key=value, ...)`` command.

    Carries global flavor settings and never references layers.
    """

    def __init__(self, cmd_str, loc, function_args=None):
        function_name = SupportedCommands.MODEL_OPTIMIZATION_FLAVOR.value
        super().__init__(function_name, cmd_str, loc, function_args=function_args)

    def __str__(self):
        # This command shouldn't have any layers, so a scope change shouldn't affect the command
        return self._command_string

    @property
    def group(self):
        return CommandsGroups.QUANTIZATION_FLAVOR

    def add_scope(self, scope_names, force=False):
        # No layers in kwargs
        pass

    def export(self):
        """Return the flavor arguments as a dict with a populated ``meta`` entry."""
        # Tolerate the ``function_args=None`` constructor default instead of failing on dict(None).
        flavor_config = dict(self.function_args or {})
        flavor_config["meta"] = None
        # Wrap in a parent dict so _update_meta finds the entry owning "meta" and fills it in place.
        cfg = {"flavor": flavor_config}
        self._update_meta(cfg, False)
        return flavor_config

    def get_layers(self):
        return []

    def validate_command(self, layers_scope_from_hn):
        return

    @classmethod
    def from_tokens(cls, tokens, cmd_str, loc):
        args = tokens.function_args.asList()
        kwargs = cls._build_kwargs(args)
        return cls(cmd_str, loc, function_args=kwargs)

    @classmethod
    def _build_kwargs(cls, function_args):
        """Merge the parsed ``{key: value}`` dicts into one ordered kwargs dict.

        Raises:
            QuantizationScriptParserException: if a key is given more than once.
        """
        kwargs = OrderedDict()
        for arg_dict in function_args:
            repeating_keys = arg_dict.keys() & kwargs.keys()
            if repeating_keys:
                # Report the command actually being parsed (previously mislabeled as
                # pre_quantization_optimization).
                command_name = SupportedCommands.MODEL_OPTIMIZATION_FLAVOR.value
                raise QuantizationScriptParserException(
                    f"Command {command_name}(...) had repeating key: {repeating_keys}",
                )
            kwargs.update(arg_dict)
        return kwargs


class FeatureOptimizationCommand(ModelOptimizationCommand, ABC):
    """Base for commands of the form ``<command>(<feature>, key=value, ...)``.

    ``self.function_args`` holds the keyword arguments. An optional ``LAYERS_KEY`` entry
    restricts the feature to specific layers. That entry may be a single layer name, a
    list of names, or a ``{layer: is_glob_derived}`` dict after glob expansion.
    """

    def __init__(self, function_name, feature, cmd_str, cmd_line, kwargs):
        self._feature = feature
        super().__init__(function_name, cmd_str, cmd_line, kwargs)

    def __str__(self):
        # Rebuilt from the current arguments so that add_scope changes are reflected.
        feature_with_args = f"{self.feature.value}"
        if self.function_args:
            kwargs = ", ".join(f"{k}={_value_to_str(v)}" for k, v in sorted(self.function_args.items()))
            feature_with_args += f", {kwargs}"

        return f"{self.function_name.value}({feature_with_args})"

    @property
    def feature(self):
        """The configured feature; its ``.value`` is used as the config key."""
        return self._feature

    @classmethod
    def _build_kwargs(cls, function_args):
        """Split parsed arguments into ``(feature, kwargs)``.

        The first element is the feature; the remaining ``{key: value}`` dicts are merged.

        Raises:
            QuantizationScriptParserException: if a key is given more than once.
        """
        feature = function_args[0]
        kwargs = OrderedDict()
        for arg_dict in function_args[1:]:
            repeating_keys = arg_dict.keys() & kwargs.keys()
            if repeating_keys:
                # NOTE(review): the message always names pre_quantization_optimization, even
                # when a subclass parses a different command; consider passing the real name.
                command_name = SupportedCommands.PRE_QUANTIZATION_OPTIMIZATION.value
                raise QuantizationScriptParserException(
                    f"Command {command_name}({feature}, ...) had repeating key: {repeating_keys}",
                )
            kwargs.update(arg_dict)
        return feature, kwargs

    def has_layers(self):
        layers = self.function_args.get(LAYERS_KEY, [])
        return len(layers) > 0

    def _export_per_layer_config(self, new_format=False):
        """Export one copy of the feature config per layer listed under ``LAYERS_KEY``.

        Args:
            new_format: if True, return ``{feature: {LAYERS_KEY: {layer: cfg}}}``;
                otherwise return ``{LAYERS_KEY: {layer: {feature: cfg}}}``.
        """
        core_config = self._export_core()
        layer_cfgs = self.function_args[LAYERS_KEY]
        if isinstance(layer_cfgs, str):
            # A single layer may be stored as a bare string (see _add_scope_to_key); wrap it
            # so the loop below does not iterate over its characters.
            layer_cfgs = [layer_cfgs]
        per_layer_config = {}
        for layer in layer_cfgs:
            # After glob expansion layers are a {name: is_glob_derived} dict; plain lists are explicit.
            is_glob = layer_cfgs[layer] if isinstance(layer_cfgs, dict) else False
            current_layer_data = copy.deepcopy(core_config)
            self._update_meta(current_layer_data, is_glob)
            per_layer_config[layer] = current_layer_data
        if new_format:
            new_data = {lname: val[self.feature.value] for lname, val in per_layer_config.items()}
            return {self.feature.value: {LAYERS_KEY: new_data}}
        else:
            return {LAYERS_KEY: per_layer_config}

    def _export_global_config(self):
        """Export the feature config (including any layers entry) with ``meta`` filled in."""
        core_config = self._export_core(False)
        self._update_meta(core_config, False)
        return core_config

    def _export_core(self, skip_layers=True):
        """Return ``{feature: {arg: value, ..., "meta": None}}``, optionally omitting ``LAYERS_KEY``."""
        if skip_layers:
            keys = filter(lambda x: x != LAYERS_KEY, self.function_args.keys())
        else:
            keys = self.function_args.keys()
        config_export = {key: self.function_args[key] for key in keys}
        config_export["meta"] = None
        return {self.feature.value: config_export}

    def get_layers(self):
        return []

    def validate_command(self, layers_scope_from_hn):
        return

    def _add_scope_to_key(self, scope_name, key, force):
        """Prefix every layer stored under ``function_args[key]`` with *scope_name*.

        A bare string is normalized to a one-element list. A dict keeps its
        is_glob_derived values.
        """
        layers = self.function_args.get(key, None)
        if layers is None:
            return
        if isinstance(layers, str):
            layers = [layers]

        if isinstance(layers, dict):
            self.function_args[key] = {
                self.add_scope_to_layer(scope_name, layer, force=force): is_glob for layer, is_glob in layers.items()
            }
        elif isinstance(layers, list):
            self.function_args[key] = [self.add_scope_to_layer(scope_name, layer, force=force) for layer in layers]
        else:
            raise NotImplementedError("Unexpected layers object")


class QuantizationParamCommand(ModelOptimizationCommand):
    """``quantization_param(<layers>, key=value, ...)`` command.

    Exports two per-layer sections:
    - ``precision_config``: keys known to ``LayerPrecisionConfig``, plus values derived
      from the deprecated legacy keys.
    - ``translation_config``: every other key.
    """

    # Deprecated arguments that are translated into precision settings instead of exported verbatim.
    LEGACY_KEYS = {"use_16bit_bias", "use_4bit_weights", "exponential_mode_4bit_weights"}

    def __init__(self, input_layers, cmd_str, loc, function_args=None, **quantization_params_to_parse):
        super().__init__(SupportedCommands.QUANTIZATION_PARAM, cmd_str, loc, function_args=function_args)
        self._command_string = cmd_str
        self._command_line = loc
        self._input_layers = input_layers
        # Kept so expand_glob always starts from the layers as written in the script.
        self._orig_input_layers = input_layers
        self._quantization_params = quantization_params_to_parse

    def __str__(self):
        quantization_params = ", ".join(
            [f"{k}={_value_to_str(v)}" for k, v in sorted(self.quantization_params.items())],
        )
        layers = _value_to_str(self._input_layers)
        return f"{self.function_name.value}({layers}, {quantization_params})"

    @classmethod
    def from_tokens(cls, tokens, cmd_str, loc):
        args = tokens.function_args.asList()
        folded_layers = [args[0]] if not isinstance(args[0], list) else args[0]
        quantization_params_to_parse = OrderedDict()

        for param in args[1:]:
            # Unwrap single-element list values to the scalar itself.
            for key, value in param.items():
                if isinstance(value, list) and len(value) == 1:
                    param[key] = value[0]

            quantization_params_to_parse.update(param)

        return cls(folded_layers, cmd_str, loc, function_args=args, **quantization_params_to_parse)

    def expand_glob(self, layers_scope_from_hn, net_scopes):
        self._input_layers = self._expand_glob_for_key(self._orig_input_layers, layers_scope_from_hn, net_scopes)

    @property
    def input_layers(self):
        return self._input_layers

    @property
    def quantization_params(self):
        return self._quantization_params

    def get_layers(self):
        return self._input_layers

    def validate_command(self, layers_scope_from_hn):
        """Ensure every input layer exists in *layers_scope_from_hn*.

        Raises:
            AllocatorScriptParserException: if no layers are set, or some cannot be found.
        """
        if self._input_layers is None:
            # Previously this case fell through to iterating None and raised TypeError.
            raise AllocatorScriptParserException(
                self.msg_prefix + "No layers were given to the command.",
            )
        if self.has_unfound_layers(layers_scope_from_hn):
            not_found = [x for x in self._input_layers if x not in layers_scope_from_hn]
            raise AllocatorScriptParserException(
                self.msg_prefix + f"Cannot find layer {not_found} in existing layers scope.",
            )

    def add_scope(self, scope_names, force=False):
        if isinstance(self._input_layers, list):
            self._input_layers = [
                self.add_scope_to_layer(scope_names, layer, force=force) for layer in self._input_layers
            ]
        elif isinstance(self._input_layers, dict):
            self._input_layers = {
                self.add_scope_to_layer(scope_names, layer, force=force): is_glob
                for layer, is_glob in self._input_layers.items()
            }
        else:
            raise NotImplementedError("Unexpected layers object")

    def _export_translation_config(self):
        """Collect the params that are neither legacy nor precision keys (enum values unwrapped)."""
        layer_data = {}
        for key in self.quantization_params:
            if (key in self.LEGACY_KEYS) or (key in LayerPrecisionConfig.keys()):
                # Legacy and precision keys are exported by the precision-config helpers.
                continue
            val = self.quantization_params[key]
            if isinstance(val, Enum):
                val = val.value
            layer_data[key] = val
        if layer_data:
            layer_data["meta"] = None
        return layer_data

    def _export_precision_config(self):
        """Collect the params known to ``LayerPrecisionConfig`` (enum values unwrapped)."""
        layer_data = {}
        for key in self.quantization_params:
            if key not in LayerPrecisionConfig.keys():
                continue
            val = self.quantization_params[key]
            if isinstance(val, Enum):
                val = val.value
            layer_data[key] = val
        if layer_data:
            layer_data["meta"] = None
        return layer_data

    def _export_deprecated_params(self):
        """Translate deprecated params into precision settings.

        Raises:
            QuantizationScriptParserException: if a translated key was also given explicitly.
        """
        data = {}
        data.update(self._export_use_16bit_bias())
        data.update(self._export_use_4bit())
        conflicting_keys = data.keys() & self.quantization_params.keys()
        if conflicting_keys:
            raise QuantizationScriptParserException(
                f"One of the deprecated fields had conflicting configurations with {conflicting_keys}",
            )
        return data

    def _export_use_4bit(self):
        """Map the deprecated 4-bit weight flags to a ``precision_mode`` entry."""
        use_4bit_weights_key = QuantizationDeprecatedParam.use_4bit_weights
        use_4bit_weights_exp_key = QuantizationDeprecatedParam.exponential_mode_4bit_weights
        use_4bit_weights = self.quantization_params.get(use_4bit_weights_key, None)
        use_4bit_weights_exp = self.quantization_params.get(use_4bit_weights_exp_key, None)
        if use_4bit_weights is not None:
            # TODO: https://hailotech.atlassian.net/browse/SDK-32495
            default_logger().deprecation_warning(
                f"quantization_param's {use_4bit_weights_key} argument will be deprecated in the near "
                f"future. Please use `precision_mode` argument instead. (line {self._command_line})",
                DeprecationVersion.APR2022,
            )
        if use_4bit_weights_exp is not None:
            # TODO: https://hailotech.atlassian.net/browse/SDK-32495
            default_logger().deprecation_warning(
                f"quantization_param's {use_4bit_weights_exp_key} argument will be deprecated in the "
                f"near future. "
                f"Please use `precision_mode` argument instead. (line {self._command_line})",
                DeprecationVersion.APR2022,
            )
        if use_4bit_weights and use_4bit_weights_exp:
            raise QuantizationScriptParserException(
                f"Invalid command in line {self._command_line}, can't set both "
                f"{use_4bit_weights_key}=True and {use_4bit_weights_exp_key}=True",
            )
        data = {}
        if use_4bit_weights:
            data = {"precision_mode": PrecisionMode.a8_w4.value, "meta": None}
        if use_4bit_weights_exp:
            data = {"precision_mode": PrecisionMode.a8_w4_exp.value, "meta": None}
        return data

    def _export_use_16bit_bias(self):
        """Map the deprecated ``use_16bit_bias`` param to a ``bias_mode`` entry.

        Raises:
            QuantizationScriptParserException: if the value matches no known policy.
        """
        data = {}
        use_16bias_key = QuantizationDeprecatedParam.use_16bit_bias
        if use_16bias_key not in self.quantization_params:
            return data
        # TODO: https://hailotech.atlassian.net/browse/SDK-32495
        default_logger().deprecation_warning(
            f"quantization_param's {use_16bias_key} argument will be deprecated in the near future. "
            f"Please use `bias_mode` argument instead. (line {self._command_line})",
            DeprecationVersion.APR2022,
        )
        use_16bit_bias = self.quantization_params[use_16bias_key]
        value = get_param_value_from_policy(use_16bit_bias, Use16bitBiasPolicies)
        if value is None:
            # Unrecognized value: report it clearly instead of failing on a lookup of None.
            raise QuantizationScriptParserException(
                f"Invalid value {use_16bit_bias} for quantization_param's {use_16bias_key} argument "
                f"(line {self._command_line})",
            )
        bias_mode = Use16bitBiasPolicyToBiasMode[value]
        data.update({"bias_mode": bias_mode.value, "meta": None})
        return data

    def export(self):
        """Return ``{"precision_config": {...}, "translation_config": {...}}`` keyed per layer."""
        layer_translation_config = self._export_translation_config()
        layer_precision_config = {}
        update_nested(layer_precision_config, self._export_deprecated_params())
        update_nested(layer_precision_config, self._export_precision_config())
        precision_config = {"layers": {}}
        translation_config = {"layers": {}}
        for layer in self._input_layers:
            # After glob expansion layers are a {name: is_glob_derived} dict; plain lists are explicit.
            is_glob = self._input_layers[layer] if isinstance(self._input_layers, dict) else False
            curr_layer_prec_cfg = copy.deepcopy(layer_precision_config)
            self._update_meta_single_field(curr_layer_prec_cfg, is_glob)
            precision_config["layers"][layer] = curr_layer_prec_cfg

            curr_layer_translation_cfg = copy.deepcopy(layer_translation_config)
            self._update_meta_single_field(curr_layer_translation_cfg, is_glob)
            translation_config["layers"][layer] = curr_layer_translation_cfg
        return {
            "precision_config": precision_config,
            "translation_config": translation_config,
        }


class CompressionParamsCommands(ModelOptimizationCommand):
    """``compression_params(key=value, ...)`` command.

    Of its arguments, only ``auto_4bit_weights_ratio`` is exported. The command has no layers.
    """

    # TODO: This looks like a legacy class, we should probably remove this
    def __init__(self, cmd_str, loc, **compression_params):
        super().__init__(SupportedCommands.COMPRESSION_PARAMS, cmd_str, loc, function_args=compression_params)

    @classmethod
    def from_tokens(cls, tokens, command_string, line_number):
        """Merge the parsed ``{key: value}`` groups and build the command."""
        merged = {}
        for group in tokens.function_args:
            merged.update(group)
        return cls(command_string, line_number, **merged)

    def get_layers(self):
        return []

    def validate_command(self, layers_scope_from_hn):
        return

    def export(self):
        """Return ``{"compression_params": {...}}``, or ``{}`` when the ratio was not given."""
        ratio = self._function_args.get("auto_4bit_weights_ratio", None)
        if ratio is None:
            return {}
        section = {
            "auto_4bit_weights_ratio": ratio,
            "meta": self._build_meta("auto_4bit_weights_ratio", False),
        }
        return {"compression_params": section}

    def add_scope(self, scope_names, force=False):
        # Nothing to scope: this command has no layers.
        pass

    def __str__(self):
        rendered_args = ", ".join(
            f"{name}={_value_to_str(val)}" for name, val in sorted(self._function_args.items())
        )
        return f"{self.function_name.value}({rendered_args})"
