import os
import re

import numpy as np
import yaml


class QnpzExporter:
    """Validate a layer's exported parameter dict against a per-layer YAML template.

    The template file (``export_quantize.yaml`` next to this module) holds a ``Layers``
    mapping keyed by layer class name. Each layer entry maps (possibly nested) parameter
    names to lists of allowed dtype names; nested names are flattened to ``"parent/child"``.
    Three reserved keys change how entries are interpreted:

    * ``OPTIONAL`` - parameters that may be present but are not required.
    * ``MACRO`` - parameter-name patterns using ``{n}``, ``{s}`` or ``{any}`` placeholders;
      every params key matching a pattern receives that pattern's allowed dtypes.
    * ``SHARED`` - parameters that are either required or stripped from ``params``,
      depending on ``export(include_shared_weights=...)``.
    """

    PARAMS_TEMPLATE = os.path.join(os.path.dirname(__file__), "export_quantize.yaml")
    # Maps dtype names used in the YAML template to numpy scalar types.
    # Names missing from this table are left as plain strings by convert_dtype_strings.
    DATATYPE_MAPPING = {
        "bool": np.bool_,
        "int8": np.int8,
        "int16": np.int16,
        "int32": np.int32,
        "int64": np.int64,
        "uint8": np.uint8,
        "uint16": np.uint16,
        "uint32": np.uint32,
        "uint64": np.uint64,
        "float16": np.float16,
        "float32": np.float32,
        "float64": np.float64,
    }

    # Reserved template keys (attribute spelling kept as-is for backward compatibility).
    SPETIAL_KEYWORDS = ["OPTIONAL", "MACRO", "SHARED"]

    # Placeholder syntax allowed in MACRO keys and the regex group each one expands to.
    _MACRO_PLACEHOLDERS = {
        "{n}": r"(\d+)",  # Matches numbers
        "{s}": r"(\w+)",  # Matches word characters
        "{any}": r"(.+)",  # Matches any character sequence
    }

    def __init__(self, class_name: str, params: dict) -> None:
        """Build the template for ``class_name`` and bind the parameters to validate.

        Args:
            class_name: Key of this layer's entry under ``Layers`` in the template file.
            params: Parameter name -> value mapping; values must expose ``.dtype``
                (presumably numpy arrays). Stored by reference, not copied.
        """
        self.class_name = class_name
        self.template = self._init_template(params)
        self.params = params

    def _init_template(self, params: dict) -> dict:
        """Load, flatten and macro-expand this layer's template against ``params``."""
        template = self.load_config()
        template = self._flatten_keys(template)
        return self._fill_macro(template, params)

    def export(self, include_shared_weights=True) -> dict:
        """Check ``self.params`` against the template and return it.

        Args:
            include_shared_weights: If True, SHARED parameters are required like any other;
                if False they are removed from ``self.params`` (in place) before validation.

        Returns:
            ``self.params`` (the same dict object passed to the constructor).

        Raises:
            ValueError: on missing required keys or unexpected extra keys.
            TypeError: when a parameter's dtype is not one of its allowed types.
        """
        # The helpers below consume the SHARED/OPTIONAL sections of self.template; only
        # top-level keys are mutated, so a shallow snapshot is enough to make export()
        # safe to call more than once.
        template_snapshot = dict(self.template)
        try:
            self._remove_shared_weights(include_shared_weights=include_shared_weights)
            self._validate_params()
        finally:
            self.template = template_snapshot
        return self.params

    def load_config(self):
        """Load this layer's section of the YAML template with dtype names resolved.

        Returns:
            The (still nested) template dict for ``self.class_name``.

        Raises:
            FileNotFoundError: if ``PARAMS_TEMPLATE`` does not exist.
            KeyError: if the file lacks a ``Layers`` section or an entry for ``self.class_name``.
        """
        if not os.path.exists(self.PARAMS_TEMPLATE):
            raise FileNotFoundError(f"Template file {self.PARAMS_TEMPLATE} not found.")
        with open(self.PARAMS_TEMPLATE, "r", encoding="utf-8") as file:
            config = yaml.safe_load(file)["Layers"][self.class_name]
        self.convert_dtype_strings(config)
        return config

    def convert_dtype_strings(self, config):
        """Replace dtype-name strings in list values with numpy types, recursively and in place.

        Unknown names are kept as strings; values that are neither lists nor dicts are untouched.
        """
        for key, value in config.items():
            if key in self.SPETIAL_KEYWORDS:
                self.convert_dtype_strings(value)
            elif isinstance(value, list):
                config[key] = [self.DATATYPE_MAPPING.get(dtype_str, dtype_str) for dtype_str in value]
            elif isinstance(value, dict):
                self.convert_dtype_strings(value)

    def _apply_macros(self, template: dict, params: dict, macro_vals: dict) -> dict:
        """Give every ``params`` key that matches a macro pattern that macro's value in ``template``.

        Placeholders in macro keys are expanded per ``_MACRO_PLACEHOLDERS``; all other
        characters match literally. The first matching macro (in dict order) wins.
        Mutates and returns ``template``.
        """
        regex_patterns = []
        for macro, macro_value in macro_vals.items():
            regex_pattern = re.escape(macro)
            for placeholder, pattern in self._MACRO_PLACEHOLDERS.items():
                # re.escape also escapes the braces, so look for the escaped placeholder form.
                regex_pattern = regex_pattern.replace(re.escape(placeholder), pattern)
            regex_patterns.append((re.compile(f"^{regex_pattern}$"), macro_value))

        for key in params:
            for pattern, macro_value in regex_patterns:
                if pattern.match(key):
                    template[key] = macro_value
                    break  # Stop checking other patterns if a match is found
        return template

    def _fill_macro(self, template: dict, params: dict) -> dict:
        """Pop the MACRO section from ``template`` and expand it against ``params`` keys.

        OPTIONAL/SHARED sections nested inside MACRO are discarded, not expanded.
        """
        macro_vals = template.pop("MACRO", {})
        macro_vals.pop("OPTIONAL", {})
        macro_vals.pop("SHARED", {})
        return self._apply_macros(template, params, macro_vals)

    def _flatten_keys(self, params: dict) -> dict:
        """Flatten nested dicts into ``"parent/child"`` keys, hoisting reserved sections.

        Reserved sections (OPTIONAL/MACRO/SHARED) found at any depth are merged into a
        single top-level section of the same name, with their keys prefixed by the path
        of the dict they were found in.
        """
        flat_params = {}
        for key, value in params.items():
            if key in self.SPETIAL_KEYWORDS:
                # Merge instead of assigning: a section hoisted from an earlier nested dict
                # must not be overwritten by this level's own section.
                flat_params.setdefault(key, {}).update(self._flatten_keys(value))
            elif isinstance(value, dict):
                flattened = self._flatten_keys(value)

                # Pull the reserved sections out before prefixing the ordinary keys.
                special_sections = {special: flattened.pop(special, {}) for special in self.SPETIAL_KEYWORDS}

                flat_params.update(add_prefix(flattened, f"{key}/"))

                # Hoist each non-empty reserved section, prefixed with the *parent* key.
                # A distinct loop variable keeps `key` bound to the parent name.
                for special, sub_keys in special_sections.items():
                    if sub_keys:
                        flat_params.setdefault(special, {}).update(add_prefix(sub_keys, f"{key}/"))

            else:
                flat_params[key] = value
        return flat_params

    def fill_macro(self, template: dict, params: dict) -> dict:
        """Pop the MACRO section from ``template`` and expand it against ``params`` keys.

        Unlike ``_fill_macro``, nested OPTIONAL/SHARED entries inside MACRO are kept and
        treated as ordinary macro patterns.
        """
        macro_vals = template.pop("MACRO", {})
        return self._apply_macros(template, params, macro_vals)

    def _remove_shared_weights(self, include_shared_weights=True):
        """Resolve the SHARED section of ``self.template``.

        Promotes shared entries to regular required entries, or, when
        ``include_shared_weights`` is False, drops those keys from ``self.params`` in place.
        Either way the SHARED key is removed from ``self.template``.
        """
        shared = self.template.pop("SHARED", {})
        if include_shared_weights:
            self.template.update(shared)
            return
        for key in shared:
            self.params.pop(key, None)

    def _validate_params(self):
        """Check key coverage and dtypes of ``self.params`` against ``self.template``.

        Consumes (pops) the OPTIONAL section of ``self.template``.

        Raises:
            ValueError: if a required template key is absent from params, or params
                contain keys that are neither required nor OPTIONAL.
            TypeError: if a parameter's ``dtype.type`` is not among its allowed types.
        """
        optional_vals = self.template.pop("OPTIONAL", {})

        extra_keys = self.params.keys() - self.template.keys()
        must_keys = self.params.keys() - extra_keys
        if self.template.keys() != must_keys:
            raise ValueError(f"Missing keys: {self.template.keys() - must_keys}, " f"Extra keys: {extra_keys}")
        if not extra_keys.issubset(optional_vals.keys()):
            raise ValueError(
                f"Extra keys: {extra_keys - optional_vals.keys()} not in optional keys: {optional_vals.keys()}"
            )

        template_compose = {**self.template, **optional_vals}

        for key, param_value in self.params.items():
            allowed_types = template_compose[key]
            if param_value.dtype.type not in allowed_types:
                # Dtype names missing from DATATYPE_MAPPING stay plain strings (no __name__).
                allowed_type_names = [getattr(dtype, "__name__", str(dtype)) for dtype in allowed_types]
                raise TypeError(
                    f"Parameter {key} has incorrect type. Allowed types: {allowed_type_names}, "
                    f"given type: {param_value.dtype.type.__name__}"
                )


def add_prefix(original_dict: dict, prefix: str) -> dict:
    """Return a new dict whose keys are ``prefix`` followed by each original key.

    Keys are rendered with ``str.format`` semantics, so non-string keys become strings.
    The input dict is not modified; values are carried over as-is (not copied).
    """
    prefixed = {}
    for name, item in original_dict.items():
        prefixed["{}{}".format(prefix, name)] = item
    return prefixed


def add_suffix(original_dict: dict, suffix: str) -> dict:
    """Return a new dict whose keys are each original key followed by ``suffix``.

    Keys are rendered with ``str.format`` semantics, so non-string keys become strings.
    The input dict is not modified; values are carried over as-is (not copied).
    """
    suffixed = {}
    for name, item in original_dict.items():
        suffixed["{}{}".format(name, suffix)] = item
    return suffixed
