from __future__ import annotations

import fnmatch
import logging
from enum import Enum
from itertools import chain
from logging import Logger

from hailo_model_optimization.acceleras.utils.flow_state_utils import (
    AtomicOpState,
    BaseFlowState,
    LayerState,
    ModelState,
)


def yield_sub_elemnts(state: BaseFlowState):
    """
    Yield the direct children of ``state`` in the flow-state tree.

    - ModelState    -> the values of ``state.layers``
    - LayerState    -> the values of ``state.atomic_ops``
    - AtomicOpState -> the values of ``input_lossy_elements``, then
      ``output_lossy_elements``, then ``weight_lossy_elements``

    Any other state (e.g. a lossy element) has no children, so nothing is yielded.
    Each ``isinstance`` check is independent rather than an if/elif chain.
    """
    if isinstance(state, ModelState):
        yield from state.layers.values()
    if isinstance(state, LayerState):
        yield from state.atomic_ops.values()
    if isinstance(state, AtomicOpState):
        # All three views are fetched before the first element is yielded.
        lossy_groups = chain(
            state.input_lossy_elements.values(),
            state.output_lossy_elements.values(),
            state.weight_lossy_elements.values(),
        )
        yield from lossy_groups


class ModelStateHandler:
    """
    Apply glob-addressed field updates to a copy of a ModelState.

    The handler is unaware of the model itself. It walks the state tree via
    ``yield_sub_elemnts`` (model -> layers -> atomic ops -> lossy elements).
    It overwrites fields on every sub-state whose ``full_name`` matches a glob
    pattern. It can be used to modify the model flow, e.g. disabling and
    enabling certain lossy elements. Intended to be used by a FlowStateUpdater
    object.

    Args:
        logger: logger used to report every changed field and invalid enum values.
    """

    def __init__(self, logger: Logger) -> None:
        self._logger = logger
        # Either a list of (glob_pattern, fields) pairs or a mapping
        # glob_pattern -> fields; set by load_params.
        self.commands: list[tuple] | dict | None = None
        self.model_state: ModelState | None = None
        # Level used for the "changed ..." log lines; set by load_params.
        self._logger_level = None

    def load_params(
        self,
        commands: list[tuple],
        model_state: ModelState,
        logger_level,
    ) -> ModelStateHandler:
        """
        Store the commands and the state to update, and return ``self`` for chaining.

        Args:
            commands: list of ``(glob_pattern, fields)`` pairs, applied in order.
                A mapping ``{glob_pattern: fields}`` is also accepted.
                ``glob_pattern`` is matched against ``full_name`` with ``fnmatch``.
                ``fields`` maps attribute names to new values:
                - dict attributes are merged (``dict.update``);
                - Enum attributes accept a member name or a member of that enum;
                - everything else is assigned as-is.
                A field missing on the matched state is looked up on its
                sub-states instead.
            model_state: the state to update. ``run`` works on ``model_state.copy()``.
            logger_level: level for the "changed ..." log lines; ``None`` means
                ``logging.INFO``.

        User-facing YAML formats:
            option 1:
        ```YAML
                'full_name_pattern':
                    field: 'value'
        ```
            option 2 (presumably normalized into (pattern, fields) pairs by the
            caller; it is not parsed here. TODO confirm):
        ```YAML
                full_name: 'full_name_pattern'
                field: 'value'
        ```

        Examples (patterns starting with '*' must be quoted in YAML):
            example 1:
        ```YAML
            '*/activation4/*':
                internal_decoding_enabled: false
                internal_encoding_enabled: false
        ```
            example 2:
        ```YAML
            full_name: '*/activation4/*'
            internal_decoding_enabled: false
            internal_encoding_enabled: false
        ```
            example 3:
        ```YAML
            '*/activation4/*':
                dict_kwgs:
                    bits: 4
        ```
        """
        self.commands = commands
        self.model_state = model_state
        self._logger_level = logger_level if logger_level is not None else logging.INFO

        return self

    def run(self) -> ModelState:
        """Return a copy of the loaded model state with all commands applied."""
        # NOTE(review): whether ModelState.copy() is deep is not visible here. If it
        # is shallow, updates to nested sub-states would also affect the original.
        model_state = self.model_state.copy()  # Ensure that this is a new object
        model_state = self._update_all_commands(model_state)

        return model_state

    def _update_all_commands(self, model_state: ModelState) -> ModelState:
        """Apply every command, in order, to ``model_state`` in place and return it."""
        commands = self.commands.items() if isinstance(self.commands, dict) else self.commands
        for name, command in commands:
            self._perform_update(model_state, name, command)
        return model_state

    def _perform_update(self, sub_state: BaseFlowState, name: str, command: dict):
        """Find the top-most states matching ``name`` under ``sub_state`` and apply ``command``."""
        if self._check_name_matches(sub_state, name):
            self._perform_update_for_matched_name(sub_state, name, command)
        else:
            for new_sub_state in yield_sub_elemnts(sub_state):
                self._perform_update(new_sub_state, name, command)

    def _perform_update_for_matched_name(self, sub_state: BaseFlowState, name: str, command: dict):
        """
        Set every field of ``command`` that ``sub_state`` has.

        If any field is missing, descend once into the sub-states with the
        full command. Previously the code descended once per missing key, which
        repeated the same (idempotent) updates and duplicated the log lines.
        Fields found nowhere in the subtree are silently ignored.
        """
        has_missing_key = False
        for key, value in command.items():
            if hasattr(sub_state, key):
                self._set_attr(sub_state, key, value)
            else:
                has_missing_key = True
        if has_missing_key:
            for new_sub_state in yield_sub_elemnts(sub_state):
                self._perform_update_for_matched_name(new_sub_state, name, command)

    def _set_attr(self, sub_state, key, value):
        """
        Set ``sub_state.<key>`` to ``value``, handling dictionaries, enums and primitive types.

        - dict attribute: merged in place with ``value``;
        - Enum attribute: ``value`` may be a member of the same enum or a member
          name. Anything else is logged as a warning and leaves the field unchanged;
        - anything else: assigned directly.
        Every successful change is logged at the configured level.
        """
        attr = getattr(sub_state, key)
        if isinstance(attr, dict):
            attr.update(value)
        elif isinstance(attr, Enum):  # Handle Enum case
            enum_class = type(attr)
            if isinstance(value, enum_class):
                setattr(sub_state, key, value)
            elif isinstance(value, str) and value in enum_class.__members__:
                # Member names are always strings; the str check also avoids a
                # TypeError on unhashable values.
                setattr(sub_state, key, enum_class[value])
            else:
                self._logger.warning(f"Invalid enum value '{value}' for '{key}' in '{sub_state.full_name}'")
                # Nothing changed, so don't emit the misleading "changed" line.
                return
        else:
            setattr(sub_state, key, value)
        self._logger.log(
            self._logger_level, f"changed '{sub_state.full_name}' in field '{key}' to value: '{getattr(sub_state,key)}'"
        )

    def _check_name_matches(self, sub_state: BaseFlowState, name: str):
        """Return True when ``sub_state.full_name`` matches the glob ``name`` (``fnmatch``, so '*' also spans '/')."""
        return fnmatch.fnmatch(sub_state.full_name, name)