from enum import Enum
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch.nn as nn

from hailo_model_optimization.acceleras.utils.acceleras_definitions import PrecisionMode, SplittedPrecisionMode
from hailo_model_optimization.saitama.framework.common.saitama_definitions import (
    APUPrecisionConfig,
    MACPrecisionConfig,
    QType,
)
from hailo_model_optimization.saitama.framework.common.saitama_module import SaitamaModule
from hailo_model_optimization.saitama.translators.base_translator import BaseLayerTranslator
from hailo_model_optimization.saitama.translators.translator_utils import KeyHandler, Tags


class LayerMode(Enum):
    """Selects which flavour of layer a translator builds for an HN element."""

    # Built through the translator's ``_build_native_layer_from_hn`` hook.
    NATIVE = "native"
    # Built through the translator's ``_build_quant_layer_from_hn`` hook.
    QUANT = "quant"
    # Deferred choice: ``HNTranslator._resolve_layer_mode`` turns this into QUANT when the
    # HN layer dict contains a "quantization_params" key, and into NATIVE otherwise.
    AUTO = "auto"


class HNTranslator(BaseLayerTranslator):
    """
    Abstract translator that defines the core logic for converting HN elements into Saitama layers.

    The main steps are:

    1. ``translate(lname, hn_elemnt, acceleras_params, *, layer_mode, dtype, device)``:
       - Logs the translation process.
       - Calls ``_resolve_layer_mode(layer_mode, hn_elemnt)`` so that ``LayerMode.AUTO`` becomes
         ``LayerMode.QUANT`` or ``LayerMode.NATIVE``.
       - Calls ``_build_layer_from_hn(...)``, which dispatches to the per-mode builder below and
         returns the new layer together with its tags.
       - Looks up ``state_dict_translation[layer_mode]`` and calls
         ``_add_weights_and_encodings(...)`` to load the Acceleras parameters into the layer.
       - Stores ``lname`` on the layer as ``original_name`` and returns the layer.
       - Any exception raised along the way is re-raised with the layer name, translator class
         and layer mode prepended to its message.

    2. ``_build_quant_layer_from_hn(hn_elemnt, tags, *, dtype, device)`` (per layer implementation):
       - Must be implemented by subclasses that support ``LayerMode.QUANT``. Returns the created
         module and the tags as a ``(layer, tags)`` tuple.

    3. ``_build_native_layer_from_hn(hn_elemnt, tags, *, dtype, device)`` (per layer implementation):
       - Must be implemented by subclasses that support ``LayerMode.NATIVE``. Returns the created
         module and the tags as a ``(layer, tags)`` tuple.

    4. ``_add_weights_and_encodings(layer, acceleras_params, state_dict_translation, hn_element, ...)``:
       - Builds a state dict through ``self.utils.create_state_dict`` and loads it into the layer
         through ``self.utils.validate_broadcast_and_load``.

    Subclasses typically define:
    - ``state_dict_translation``: for each supported ``LayerMode``, the list of
      ``(name, KeyHandler)`` pairs handed to ``self.utils.create_state_dict``.

    Steps to Add a New Translator
    -----------------------------
    1. Inherit from ``HNTranslator``.
    2. Implement ``_build_quant_layer_from_hn`` and/or ``_build_native_layer_from_hn`` to return a
       new layer and its tags.
    3. Define (if needed) a ``state_dict_translation`` mapping for the supported layer modes.
    4. Decorate your class with ``@HN_TRANSLATION_REGISTRY(LayerType.<YOUR_LAYER_TYPE>)``.
    """

    # Default: nothing to translate in either mode. This dict (and its lists) is shared by every
    # subclass that does not override it, so subclasses should assign their own mapping instead
    # of mutating this one in place.
    state_dict_translation: Dict[LayerMode, List[Tuple[str, KeyHandler]]] = {
        LayerMode.QUANT: [],
        LayerMode.NATIVE: [],
    }

    def _build_layer_from_hn(
        self,
        hn_layer: dict,
        tags: Tags,
        layer_mode: LayerMode,
        *,
        dtype=None,
        device=None,
    ) -> Tuple[Union[SaitamaModule, nn.Module], Tags]:
        """
        Build the layer for ``hn_layer`` using the builder that matches ``layer_mode``.

        Args:
            hn_layer: The HN element describing the layer.
            tags: Tags forwarded to the per-mode builder.
            layer_mode: An already resolved mode; only ``NATIVE`` and ``QUANT`` are accepted here.
            dtype: Forwarded unchanged to the per-mode builder.
            device: Forwarded unchanged to the per-mode builder.

        Returns:
            The ``(layer, tags)`` tuple returned by the selected builder.

        Raises:
            ValueError: If ``layer_mode`` is neither ``NATIVE`` nor ``QUANT`` (e.g. an unresolved
                ``AUTO``).
        """
        if layer_mode is LayerMode.NATIVE:
            return self._build_native_layer_from_hn(hn_layer, tags, dtype=dtype, device=device)
        if layer_mode is LayerMode.QUANT:
            return self._build_quant_layer_from_hn(hn_layer, tags, dtype=dtype, device=device)
        raise ValueError(f"Unsupported layer mode: {layer_mode}")

    def _build_native_layer_from_hn(
        self,
        hn_elemnt: dict,
        tags: Tags,
        *,
        dtype=None,
        device=None,
    ) -> Tuple[Union[SaitamaModule, nn.Module], Tags]:
        """
        Create and return a new native module (and its tags) from the provided HN element.

        Raises:
            NotImplementedError: Always, in this base class; subclasses supporting
                ``LayerMode.NATIVE`` must override this hook.
        """
        raise NotImplementedError(f"This method should be implemented by the subclass - {self.__class__.__name__}")

    def _build_quant_layer_from_hn(
        self,
        hn_elemnt: dict,
        tags: Tags,
        *,
        dtype=None,
        device=None,
    ) -> Tuple[Union[SaitamaModule, nn.Module], Tags]:
        """
        Create and return a new quant module (and its tags) from the provided HN element.

        Raises:
            NotImplementedError: Always, in this base class; subclasses supporting
                ``LayerMode.QUANT`` must override this hook.
        """
        raise NotImplementedError(f"This method should be implemented by the subclass - {self.__class__.__name__}")

    def _add_weights_and_encodings(
        self,
        layer: Union[SaitamaModule, nn.Module],
        acceleras_params: Dict[str, np.ndarray],
        state_dict_translation: List[Tuple[str, KeyHandler]],
        hn_element: dict,
        *,
        tags: Optional[Tags] = None,
        dtype=None,
        device=None,
    ) -> Union[SaitamaModule, nn.Module]:
        """
        Build a state dict from the given acceleras parameters and load it into ``layer``.

        Args:
            layer: The layer that receives the weights and encoding information.
            acceleras_params: Acceleras parameters the state dict is built from.
            state_dict_translation: Translation entries for the active layer mode, i.e. one value
                of the class-level ``state_dict_translation`` mapping.
            hn_element: The HN element describing the layer.
            tags: Tags forwarded to ``self.utils.create_state_dict``. ``None`` (the default) is
                treated as "no tags" and replaced by an empty set.
            dtype: Forwarded unchanged to ``self.utils.create_state_dict``.
            device: Forwarded unchanged to ``self.utils.create_state_dict``.

        Returns:
            The layer returned by ``self.utils.validate_broadcast_and_load``.
        """
        if tags is None:
            # Same "no tags" value that ``translate`` starts from.
            tags = set()
        state_dict = self.utils.create_state_dict(
            acceleras_params,
            state_dict_translation,
            hn_element,
            tags=tags,
            device=device,
            dtype=dtype,
        )
        layer = self.utils.validate_broadcast_and_load(layer, state_dict)
        return layer

    def translate(
        self,
        lname: str,
        hn_elemnt: dict,
        acceleras_params: dict,
        *,
        layer_mode: LayerMode = LayerMode.AUTO,
        dtype=None,
        device=None,
    ) -> Union[SaitamaModule, nn.Module]:
        """
        Translate a single HN element into a layer loaded with its acceleras parameters.

        Args:
            lname: Name of the layer; logged and stored on the result as ``original_name``.
            hn_elemnt: The HN element describing the layer.
            acceleras_params: Acceleras parameters used to populate the layer.
            layer_mode: Requested mode; ``LayerMode.AUTO`` is resolved by ``_resolve_layer_mode``.
            dtype: Forwarded to the layer builder and to the state-dict creation.
            device: Forwarded to the layer builder and to the state-dict creation.

        Returns:
            The built layer, with ``original_name`` set to ``lname``.

        Raises:
            Exception: Whatever the underlying steps raise. The exception type is preserved;
                only its ``args`` are replaced by a message that adds the layer name, the
                translator class and the layer mode.
        """
        self.logger.info(f"Translating layer {lname}")
        tags = set()
        try:
            layer_mode = self._resolve_layer_mode(layer_mode, hn_elemnt)
            layer, tags = self._build_layer_from_hn(hn_elemnt, tags, layer_mode, dtype=dtype, device=device)
            # A mode missing from the mapping raises KeyError, which gets the context below.
            translation_dict = self.state_dict_translation[layer_mode]
            layer = self._add_weights_and_encodings(
                layer,
                acceleras_params,
                translation_dict,
                hn_elemnt,
                dtype=dtype,
                device=device,
                tags=tags,
            )
            layer.original_name = lname
        except Exception as e:
            # Rewrite the message in place and re-raise the same exception object, so callers
            # still catch the original type and see the original traceback.
            e.args = (
                f"Error while translating layer {lname}, translator {self.__class__.__name__}, "
                f"mode: {layer_mode} \n {e}",
            )
            raise
        return layer

    @staticmethod
    def _quantization_params_to_precision_config(
        quantization_params: dict,
    ) -> Tuple[MACPrecisionConfig, APUPrecisionConfig]:
        """
        Derive the MAC and APU precision configs from an HN ``quantization_params`` dict.

        Args:
            quantization_params: Must contain the keys ``"bias_mode"``, ``"precision_mode"``,
                ``"signed_output"`` and ``"quantization_groups"``.

        Returns:
            A ``(mac_cfg, apu_cfg)`` tuple. Both configs share the same accumulator qtype and
            quantization groups.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        bias_mode = quantization_params["bias_mode"]
        precision_mode = quantization_params["precision_mode"]

        b_pm = SplittedPrecisionMode.from_precision_mode(PrecisionMode(precision_mode))
        # NOTE(review): the second QType argument looks like the "signed" flag (it receives
        # ``signed_output`` below) - confirm against the QType definition.
        # Input width is capped at 15.
        input_qtype = QType(min(b_pm.input, 15), False)
        if b_pm.weights == 16:
            # 16-bit weights are mapped to QType(15, False) rather than QType(16, True).
            weight_qtype = QType(15, False)
        else:
            weight_qtype = QType(b_pm.weights, True)

        # TODO Any chance this will be deprecated? having input 8bit but because we suffer of overflow we want to have 32bit accumulator
        # The accumulator width is twice the *uncapped* input width.
        accumulator_qtype = QType(b_pm.input * 2, True)
        # TODO we might want to check if it is signed or not
        # Output width is capped at 15, like the input.
        output_qtype = QType(min(b_pm.output, 15), quantization_params["signed_output"])

        mac_cfg = MACPrecisionConfig(
            input_qtype=input_qtype,
            weight_qtype=weight_qtype,
            accumulator_qtype=accumulator_qtype,
            bias_mode=bias_mode,
            quantization_groups=quantization_params["quantization_groups"],
        )

        apu_cfg = APUPrecisionConfig(
            accumulator_qtype=accumulator_qtype,
            output_qtype=output_qtype,
            quantization_groups=quantization_params["quantization_groups"],
        )

        return mac_cfg, apu_cfg

    def _resolve_layer_mode(self, layer_mode: LayerMode, hn_layer: dict) -> LayerMode:  # TODO change to static
        """
        Resolve ``LayerMode.AUTO`` into a concrete mode; other modes are returned unchanged.

        ``AUTO`` becomes ``QUANT`` when ``hn_layer`` has a ``"quantization_params"`` key and
        ``NATIVE`` otherwise.
        """
        if layer_mode is LayerMode.AUTO:
            return LayerMode.QUANT if "quantization_params" in hn_layer else LayerMode.NATIVE
        return layer_mode
