import time
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, Tuple

from numpy.typing import ArrayLike
from pydantic.v1 import BaseModel, Field

from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import FeatureConfigBaseModel
from hailo_model_optimization.acceleras.utils.logger import default_logger
from hailo_model_optimization.acceleras.utils.modification_tracker.modification_tracker_utils import (
    AlgoModificationTracker,
)


class AlgorithmStatus(Enum):
    """State/outcome of an algorithm run, reported through ``AlgoResults.status``.

    ``AlgorithmBase.run`` sets SKIPPED, SUCCESSFULLY_DONE, UNFINISHED and FAILED.
    NOTE(review): NOT_DEFINE and NOT_RUN are not assigned by ``AlgorithmBase``;
    presumably they describe "no status yet" states — confirm with other users.
    """

    NOT_DEFINE = "NOT_DEFINE"
    NOT_RUN = "NOT_RUN"
    SKIPPED = "SKIPPED"  # should_skip_algo() returned True, so nothing was run
    SUCCESSFULLY_DONE = "SUCCESSFULLY_DONE"  # _run_int() returned without being interrupted
    UNFINISHED = "UNFINISHED"  # interrupted (KeyboardInterrupt) under the SAVE_MID_RUN checkpoint policy
    FAILED = "FAILED"  # interrupted (KeyboardInterrupt) under any other checkpoint policy


class CheckPointPolicy(Enum):
    """Checkpointing capability an algorithm declares via ``AlgorithmBase.CHECKPOINT_POLICY``."""

    NOT_SUPPORTED = "NOT_SUPPORTED"  # AlgorithmBase default; an interrupt marks the run FAILED
    # NOTE(review): AlgorithmBase neither restores nor saves state for this policy and treats an
    # interrupt as FAILED; any "save at start" handling presumably lives elsewhere — confirm.
    SAVE_AT_START = "SAVE_AT_START"
    SAVE_MID_RUN = "SAVE_MID_RUN"  # run() restores a given memento; an interrupt saves state and marks UNFINISHED


class AlgoResults(BaseModel):
    """Summary of a single algorithm run (returned by ``AlgorithmBase.get_results``).

    NOTE(review): ``use_enum_values`` only converts enums for values that go through
    validation (e.g. constructor arguments). Plain attribute assignment is not validated
    by this config, so ``status`` may hold either an ``AlgorithmStatus`` member or its
    string value depending on how it was set.
    """

    name: str = Field("", description="Name of the algorithm")
    # The default used to be ``False``, which is not an AlgorithmStatus member; pydantic v1 does
    # not validate defaults, so that bogus value leaked to callers. NOT_RUN is the initial state.
    status: AlgorithmStatus = Field(AlgorithmStatus.NOT_RUN, description="Current status of the algorithm run")
    skip_at_start: bool = Field(False, description="Whether the algorithm was skipped before it started")
    algorithm_time: float = Field(0, description="Number of seconds the algorithm took", ge=0)
    algo_memento: Optional[BaseModel] = Field(description="Information for an algorithm to store/restore its state")

    class Config:
        use_enum_values = True


class AlgorithmBase(ABC):
    """
    Base class for all the algorithms in Model Optimization.
    It gives a common interface for all the MO blocks and a generic run flow:
    skip check -> config logging -> setup -> (optional restore) -> run -> timing -> statistics.

    Args:
        model: Mutable, the algorithm may change the model.
        model_config: ModelOptimizationConfig - params needed for the block.
        name: the name of the algorithm.
        logger_level: integer with standard logging levels, used for the start/skip/done messages.
        logger: the logger to use; ``default_logger()`` is used when omitted.

    Example (``MyAlgorithm`` being a concrete subclass; the base class itself is abstract):
        >>> model = HailoModel()
        >>> model_config = ModelOptimizationConfig()
        >>> mo_algo = MyAlgorithm(model, model_config, "My Algorithm", 30)
        >>> mo_algo.run()

    """

    CHECKPOINT_POLICY = CheckPointPolicy.NOT_SUPPORTED

    def __init__(
        self,
        model: HailoModel,
        model_config: ModelOptimizationConfig,
        name: str,
        logger_level: int,
        logger=None,
        *args,
        **kl_kwargs,
    ):
        # Extra positional/keyword arguments are accepted (and ignored here) so subclasses
        # can share a common construction call.
        self._results = AlgoResults(name=name)

        self._logger_level = logger_level
        self._logger = logger or default_logger()
        self._model = model
        self._model_config = model_config
        self._name = name
        self._statistics = None
        self._modifications_meta_data = AlgoModificationTracker()

    @property
    def name(self) -> str:
        """Read-only algorithm name, lower-cased with spaces replaced by underscores."""
        return self._name.lower().replace(" ", "_")

    @abstractmethod
    def get_algo_config(self) -> FeatureConfigBaseModel:
        """
        Return the current algorithm configuration.
        """

    @abstractmethod
    def _setup(self):
        """
        Validate the inputs and check that the data is ready to be used.
        """

    @abstractmethod
    def _run_int(self):
        """
        The algorithm's actual work.
        This is an internal function which is wrapped by ``run``.
        """

    @abstractmethod
    def should_skip_algo(self) -> bool:
        """
        Decide whether to skip the algorithm, based on the algorithm configuration.
        """

    def save(self) -> BaseModel:
        """Return a memento of the algorithm state, if supported.

        The base implementation returns None; algorithms with the SAVE_MID_RUN
        checkpoint policy are expected to override it.
        """

    def restore(self, memento: BaseModel) -> None:
        """Load the algorithm state from a memento based on a Pydantic BaseModel.

        Each algorithm should define its own memento if supported; the base
        implementation does nothing.
        """

    def export_statistics(self) -> Optional[Dict[str, ArrayLike]]:
        """
        Return the statistics extracted during the algorithm run (None in the base class).
        Should be either None or Dict[str, ArrayLike] where each key is in the following form:
        {network_name}/{layer_name}/{param_name}
        """

    def _parse_statistics(self) -> Optional[Dict[str, ArrayLike]]:
        """
        Change keys from self.export_statistics() to contain the algorithm name. That is, changing each key from
        {network_name}/{layer_name}/{param_name} to {network_name}/{layer_name}/{algorithm_name}/{param_name}

        Raises:
            ValueError: if a key has fewer than three "/"-separated parts.
        """
        statistics = self.export_statistics()
        if statistics is None:
            return None
        algorithm_name = self.name
        new_statistics = {}
        for key, value in statistics.items():
            # maxsplit=2 keeps any further "/" inside param_name intact.
            network_name, layer_name, param_name = key.split("/", 2)
            new_key = "/".join([network_name, layer_name, algorithm_name, param_name])
            new_statistics[new_key] = value
        return new_statistics

    def run(self, *, memento: Optional[BaseModel] = None) -> HailoModel:
        """
        Run the algorithm end to end and return the model it was constructed with.

        A KeyboardInterrupt raised while restoring or inside ``_run_int`` is not propagated:
        it is handed to ``_gracefully_save_state``, which sets the status to UNFINISHED
        (SAVE_MID_RUN policy, state saved) or FAILED (any other policy).

        Args:
            memento: state previously produced by ``save``; it is only restored when
                ``CHECKPOINT_POLICY`` is ``SAVE_MID_RUN``.

        Returns:
            ``self._model`` (possibly modified in place by the algorithm).
        """
        if self.should_skip_algo():
            self._logger.log(self._logger_level, f"{self._name} skipped")
            self._results.status = AlgorithmStatus.SKIPPED
        else:
            self._logger.log(self._logger_level, f"Starting {self._name}")
            self.log_config()
            self._setup()
            # Bind the start time before the try block: previously an interrupt during
            # restore() left start_time unbound and crashed the elapsed-time computation.
            # perf_counter is monotonic, so the duration is immune to wall-clock changes.
            start_time = time.perf_counter()
            try:
                if memento is not None and self.CHECKPOINT_POLICY == CheckPointPolicy.SAVE_MID_RUN:
                    self.restore(memento)
                    # Keep the reported time limited to the algorithm itself, as before.
                    start_time = time.perf_counter()
                self._run_int()
                self._results.status = AlgorithmStatus.SUCCESSFULLY_DONE
            except KeyboardInterrupt:
                self._results = self._gracefully_save_state(self._results)
            elapsed = time.perf_counter() - start_time
            hours, rem = divmod(elapsed, 3600)
            minutes, seconds = divmod(rem, 60)
            self._results.algorithm_time = elapsed

            self._logger.log(
                self._logger_level,
                f"Model Optimization Algorithm {self._name} is done (completion time is {int(hours):0>2}:{int(minutes):0>2}:{seconds:05.2f})",
            )

        self._statistics = self._parse_statistics()
        return self._model

    def get_statistics(self) -> dict:
        """Return the parsed statistics of the last run, or an empty dict if there are none."""
        stats = self._statistics
        if stats is None:
            return {}
        return stats

    def get_results(self) -> AlgoResults:
        """Return the results record (name, status, timing, memento) of this algorithm."""
        return self._results

    def get_modifications_meta_data(self) -> AlgoModificationTracker:
        """Return the modification tracker created for this algorithm."""
        return self._modifications_meta_data

    def log_config(self):
        """Log the user-facing and advanced parameters of the algorithm config at verbose level.

        Internal parameters (except "meta") are not logged. Keys are sorted so the
        log output is deterministic across runs.
        """
        config = self.get_algo_config()
        keys = config.keys()
        internal_params = config.internal_keys() - {"meta"}
        advanced_params = config.advanced_keys() if hasattr(config, "advanced_keys") else set()
        advanced_params = advanced_params - internal_params
        user_params = keys - internal_params - advanced_params
        self._logger.verbose(f"{self._name} configuration:")
        for key in sorted(user_params):
            self._logger.verbose(f"\t\t{key}: {getattr(config, key)}")
        if advanced_params:
            self._logger.verbose(f"{self._name} advanced configuration:")
            for key in sorted(advanced_params):
                self._logger.verbose(f"\t\t{key}: {getattr(config, key)}")

    def _gracefully_save_state(self, results: AlgoResults) -> AlgoResults:
        """Update ``results`` after an interrupted run, according to ``CHECKPOINT_POLICY``.

        SAVE_MID_RUN: store the memento returned by ``save()`` and mark the run UNFINISHED.
        Any other policy (NOT_SUPPORTED, SAVE_AT_START): mark the run FAILED.

        Returns:
            The same ``results`` object, mutated in place.
        """
        if self.CHECKPOINT_POLICY == CheckPointPolicy.SAVE_MID_RUN:
            # save() must be called: previously the bound method itself was stored as the memento.
            results.algo_memento = self.save()
            results.status = AlgorithmStatus.UNFINISHED
        else:
            results.status = AlgorithmStatus.FAILED

        return results
