from __future__ import annotations

import logging
from enum import Enum
from logging import Logger

from hailo_model_optimization.acceleras.model.hailo_model.hailo_model import (
    HailoModel,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasUnsupportedError
from hailo_model_optimization.acceleras.utils.flow_state.model_state_handler import (
    ModelStateHandler,
)
from hailo_model_optimization.acceleras.utils.flow_state_utils import (
    ModelState,
)


class ElementState(Enum):
    """
    Target numeric state of a model element, used by FlowStateUpdater._change_elements_state.

    The member names double as their values.
    """

    # status "NUMERIC", is_lossless False
    LOSSY = "LOSSY"
    # status "FULLY_NATIVE", is_lossless True
    NATIVE = "NATIVE"
    # status "NUMERIC", is_lossless True
    LOSSLESS = "LOSSLESS"


class FlowStateUpdater:
    """
    Performs updates that configure the model flow_state, applying them to the model directly.

    Usage:
    ```python
      flow_state_updater = FlowStateUpdater(model, logger=self.logger)
      flow_state_updater.load_params(commands=commands_dict).run()
    ```
    Commands can be given either as a full substitute dictionary (its keys match the
    ModelState schema exactly), or as a partial one.
    A partial command is a dict whose layer names use glob syntax.
    format:
        option 1 (mapping of layer-name glob -> fields):
    ```YAML
            '<layer_name_glob>':
                field: 'value'
    ```
        option 2 (single command; the glob is given under the literal ``full_name`` key):
    ```YAML
            full_name: '<layer_name_glob>'
            field: 'value'
    ```
    Glob patterns starting with ``*`` must be quoted in YAML, otherwise they are parsed as aliases.

    When ``enfore_validity`` is set, fields listed in ``ENCODING_KEYS`` are applied after the
    automatically generated encoding corrections. Fields listed in ``MACROS_KEYS`` are expanded
    as macros (``target: NATIVE`` or ``target: QUANTIZED``).

    examples:
        example 1:
    ```YAML
        '*/activation4/*':
            internal_decoding_enabled: false
            internal_encoding_enabled: false
    ```

        example 2:
    ```YAML
        full_name: '*/activation4/*'
        internal_decoding_enabled: false
        internal_encoding_enabled: false
    ```
        example 3:
    ```YAML
        '*/activation4/*':
            dict_kwgs:
                bits: 4
    ```
        example 4:
    ```YAML
        '*/activation4/*':
            target: NATIVE
    ```
    """

    MACROS_KEYS = ("target",)  # To be extended with future macros
    # Fields applied only after the automatic internal-encoding corrections (enforce-validity flow)
    ENCODING_KEYS = (
        "internal_encoding_enabled",
        "internal_decoding_enabled",
        "quant_inputs_enabled",
        "enforce_internal_encoding_in_call",
    )

    def __init__(
        self,
        model: HailoModel,
        logger: Logger | None = None,
        *,
        enfore_validity: bool = True,
    ):
        """
        Args:
            model: the model whose flow state is updated.
            logger: logger handed to the internal ModelStateHandler.
            enfore_validity: when True (default), partial commands go through the
                enforce-validity flow (see ``_run_enforce_validity``). Otherwise they are applied as-is.
                NOTE: the parameter is spelled "enfore" (sic); renaming it would break keyword callers.
        """
        self._logger = logger
        self._model = model
        self._model_state_handler: ModelStateHandler = ModelStateHandler(logger)
        self._model_state: ModelState | None = None
        self.commands: dict | None = None
        self.enfore_validity: bool = enfore_validity

    def load_params(self, commands: dict) -> FlowStateUpdater:
        """Store the commands to apply on the next ``run()``; returns ``self`` for chaining."""
        self.commands = commands
        return self

    def run(self) -> HailoModel:
        """
        Apply the loaded commands to the model and return the model.

        A full ModelState dict replaces the flow state wholesale. Otherwise the commands are
        treated as partial and applied either through the enforce-validity flow or directly.
        """
        if self._is_full_file():
            self._model.import_flow_state(ModelState.parse_obj(self.commands))
        elif self.enfore_validity:
            self._run_enforce_validity(self.commands)
        else:
            self._update_model(list(self._get_name_command_tuple(self.commands)))
        return self._model

    def _update_model(self, commands: list[tuple], logger_level=None):  # TODO update to handle lists
        """
        Apply ``(glob_name, fields)`` commands to the model's current flow state.

        The flow state is exported, patched by the ModelStateHandler, and imported back.
        An empty (or None) command list is a no-op.

        Args:
            commands: list of ``(glob_name, fields_dict)`` tuples.
            logger_level: forwarded to ``ModelStateHandler.load_params``. None keeps the handler's default.
        """
        if not commands:
            return
        model_state = self._model.export_flow_state()
        new_state = self._model_state_handler.load_params(
            commands,
            model_state,
            logger_level,
        ).run()
        self._model.import_flow_state(new_state)

    def _is_full_file(self) -> bool:
        """
        Compare the keys of ModelState with the received keys, to determine whether the dict is a full model state.
        """
        return ModelState.schema()["properties"].keys() == self.commands.keys()

    def _run_enforce_validity(self, commands: dict):
        """
        Apply partial commands in an order that keeps internal encoding valid.

        Steps:
          1. Apply the regular (non-encoding, non-macro) commands.
          2. Apply the model's automatically generated internal-encoding corrections (logged at DEBUG).
          3. Apply explicit encoding commands, so they override the corrections.
          4. Run macros, then call ``enforce_internal_encoding`` on the model.
        """
        # TODO: validate that calling self._model.enable_internal_encoding() first is indeed unnecessary

        # Separate the commands and run all non-encoding ones
        command_list, internal_encoding_commands, macros_list = self._preprocess_comands(commands)
        self._update_model(command_list)

        # Get the automatically generated "disable internal encoding" commands
        generated_encoding_corrections = self._model.export_disable_internal_encoding()
        # Run them first, then the user's explicit encoding commands on top
        generated_encoding_corrections = list(self._get_name_command_tuple(generated_encoding_corrections))
        self._update_model(generated_encoding_corrections, logger_level=logging.DEBUG)
        self._update_model(internal_encoding_commands)

        self._run_macros_list(macros_list)
        self._model.enforce_internal_encoding()

    def _get_name_command_tuple(self, command: dict):
        """
        Yield ``(name, fields)`` pairs from a partial command dict.

        Supports both formats described in the class docstring:
          - a single command carrying its target under the ``full_name`` key;
          - a mapping of ``name -> fields``.

        The input dict is not modified, so the same commands can be processed more than once.
        """
        if "full_name" in command:
            # Pop from a copy: popping from the caller's dict would corrupt self.commands for later runs
            fields = dict(command)
            yield fields.pop("full_name"), fields
        else:
            yield from command.items()

    def _preprocess_comands(self, command: dict) -> tuple[list, list, list]:
        """
        Split partial commands into three ``(name, fields)`` lists.

        Returns:
            ``(command_list, internal_encoding_commands, macros_list)``: regular fields,
            ENCODING_KEYS fields, and MACROS_KEYS fields. Empty groups are omitted per name.
        """
        command_list = []
        internal_encoding_commands = []
        macros_list = []
        for name, fields in self._get_name_command_tuple(command):
            encoding_fields, macro_fields, other_fields = self._separate_dict(fields)

            if encoding_fields:
                internal_encoding_commands.append((name, encoding_fields))
            if macro_fields:
                macros_list.append((name, macro_fields))
            if other_fields:
                command_list.append((name, other_fields))

        return command_list, internal_encoding_commands, macros_list

    def _run_macros_list(self, macros_list: list):
        """Run every ``(layer, {macro_name: value})`` entry, in order."""
        for layer, macros in macros_list:
            for macro_name, macro_value in macros.items():
                self._run_single_macro(layer, macro_name, macro_value)

    def _run_single_macro(self, layer: str, command_name: str, command_value: str):
        """
        Dispatch a single macro for ``layer``.

        Raises:
            AccelerasUnsupportedError: if ``command_name`` is not a known macro.
        """
        if command_name == "target":  # To be Expanded with further macros
            self._run_flow_target_macro(command_value, layer)
        else:
            raise AccelerasUnsupportedError(f"Unsupported macro {command_name}")

    def _run_flow_target_macro(self, command_value: str, layer: str):
        """
        Expand the ``target`` macro: ``"NATIVE"`` turns the layer native, ``"QUANTIZED"`` turns it lossy.

        Raises:
            AccelerasUnsupportedError: for any other target value.
        """
        if command_value == "NATIVE":  # To be changed to switch case with python 3.10
            self.turn_elements_native([layer])
        elif command_value == "QUANTIZED":
            self.turn_elements_lossy([layer])
        else:
            raise AccelerasUnsupportedError(f"Unsupported optimization target {command_value}")

    def _separate_dict(self, input_dict: dict | None) -> tuple[dict, dict, dict]:
        """
        Separate a command's fields into ``(encoding, macros, other)`` dicts.

        Fields in ENCODING_KEYS go to the first dict, fields in MACROS_KEYS to the second,
        and everything else to the third. A None command yields three empty dicts.
        """
        if input_dict is None:
            # Must match the 3-tuple arity of the normal path; callers unpack three values
            return {}, {}, {}
        macros_keys_dict = {}
        encoding_keys_dict = {}
        other_keys_dict = {}

        for key, value in input_dict.items():
            if key in self.ENCODING_KEYS:
                encoding_keys_dict[key] = value
            elif key in self.MACROS_KEYS:
                macros_keys_dict[key] = value
            else:
                other_keys_dict[key] = value

        return encoding_keys_dict, macros_keys_dict, other_keys_dict

    def _change_elements_state(self, glob_syntax_elements: list[str], state: ElementState):
        """
        Generic method to change element states; accepts glob syntax names.

        Applies the fields mapped to ``state`` to every glob, then calls
        ``enforce_internal_encoding`` on the model.

        Raises:
            AccelerasUnsupportedError: if ``glob_syntax_elements`` is not a list.
        """
        # Validate first: a bare string would otherwise be iterated character by character
        if not isinstance(glob_syntax_elements, list):
            raise AccelerasUnsupportedError(
                f"Expected 'glob_syntax_elements' to be a list, got {type(glob_syntax_elements).__name__} instead."
            )
        state_mapping = {
            ElementState.LOSSY: {
                "status": "NUMERIC",
                "is_lossless": False,
                "aops_dict_kwgs": {"is_lossless": False},
            },
            ElementState.NATIVE: {
                "status": "FULLY_NATIVE",
                "is_lossless": True,
            },
            ElementState.LOSSLESS: {
                "status": "NUMERIC",
                "is_lossless": True,
            },
        }

        change_list = [(name, state_mapping[state]) for name in glob_syntax_elements]
        self._update_model(change_list)
        self._model.enforce_internal_encoding()

    def turn_elements_lossy(self, glob_syntax_elements: list[str]):
        """
        Macro to turn elements lossy; accepts glob syntax names.
        """
        self._change_elements_state(glob_syntax_elements, ElementState.LOSSY)

    def turn_elements_native(self, glob_syntax_elements: list[str]):
        """
        Macro to turn elements native; accepts glob syntax names.
        """
        self._change_elements_state(glob_syntax_elements, ElementState.NATIVE)

    def turn_elements_lossless(self, glob_syntax_elements: list[str]):
        """
        Macro to turn elements lossless; accepts glob syntax names.
        """
        self._change_elements_state(glob_syntax_elements, ElementState.LOSSLESS)
