import atexit
import multiprocessing
import os
import shutil
import signal
import tempfile
import traceback
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import wraps
from pathlib import Path
from typing import Generic, Optional, Set, TypeVar

from pydantic.v1 import BaseModel

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DistributionStrategy,
    GPUAvailabilityMode,
    ThreeWayPolicy,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    SubprocessTracebackFailure,
    SubprocessUnexpectedFailure,
)
from hailo_model_optimization.acceleras.utils.distributed_utils import (
    DistContextInfo,
    get_strategy,
    gpu_distributed_context,
)
from hailo_model_optimization.acceleras.utils.tf_utils import get_gpu_availability_mode

# Path of the single shared temporary directory; created lazily by get_working_temp_path().
_WORKING_TEMP_DIR: Optional[Path] = None


def _cleanup_temp_dir():
    """Remove the shared temporary directory (registered with ``atexit``).

    Does nothing if the directory was never created or no longer exists.
    """
    global _WORKING_TEMP_DIR
    if _WORKING_TEMP_DIR and _WORKING_TEMP_DIR.exists():
        shutil.rmtree(_WORKING_TEMP_DIR, ignore_errors=True)
        _WORKING_TEMP_DIR = None


def get_working_temp_path() -> Path:
    """
    Return a Path to a single shared temporary directory.

    The directory (prefix ``acceleras_tmp_``) is created on the first call. If it has been
    removed since, e.g. by ``_cleanup_temp_dir`` or an external temp cleaner, a new one is
    created so callers never receive a path that no longer exists. Removal is registered
    with ``atexit``, so the directory is deleted at normal interpreter exit but not if the
    process is killed abruptly.

    Returns:
        Path: An existing directory that callers may create sub-directories in.
    """
    global _WORKING_TEMP_DIR

    if _WORKING_TEMP_DIR is None or not _WORKING_TEMP_DIR.is_dir():
        _WORKING_TEMP_DIR = Path(tempfile.mkdtemp(prefix="acceleras_tmp_"))
        # unregister() first so re-creating the directory never stacks duplicate exit handlers.
        atexit.unregister(_cleanup_temp_dir)
        atexit.register(_cleanup_temp_dir)

    return _WORKING_TEMP_DIR


@dataclass
class Message:
    """Result of a successful subprocess step, handed back to the parent process.

    Attributes:
        memento: The value returned by the flow's ``save_state``.
        graph: The value returned by ``supervisor.dump()``.
    """

    # Annotated so @dataclass treats these as real fields; unannotated class attributes are
    # ignored by dataclasses, which would leave a field-less class (no constructor args,
    # every instance comparing equal).
    memento: object = None
    graph: object = None


def subprocess_wrapper(
    default_gpu_policy: DistributionStrategy = DistributionStrategy.SINGLE,
    supported_gpu_policy: Optional[Set[DistributionStrategy]] = None,
):
    """
    Decorator factory that may run a ``BaseSubprocessFlow`` method in a child process.

    A subprocess is created only if the GPU is free (``GPUAvailabilityMode.NOT_IN_USE``).
    The child runs the decorated method inside ``gpu_distributed_context`` and saves the flow
    state into a fresh directory under ``get_working_temp_path()``. It then sends that state,
    together with ``supervisor.dump()``, back over a pipe, and the parent reloads both.
    Otherwise the method runs in the current process after ``build_model()``. If the
    ``HAILO_FORCE_STEP_REBUILD`` environment variable is ``"True"``, this is followed by a
    save/load round-trip.

    Args:
        default_gpu_policy: Default distribution strategy forwarded to ``get_strategy``.
        supported_gpu_policy: Optional set of strategies forwarded to ``get_strategy`` as
            ``supported_gpu_info``.

    Raises:
        SubprocessTracebackFailure: The child caught an exception. The structured
            ``{"exception": ..., "traceback": ...}`` dicts it sent are attached.
        SubprocessUnexpectedFailure: The child exited with a non-zero exit code or
            without sending back a result.
    """

    def decorator(func):
        @wraps(func)
        def parent_wrapper(self, *args, **kwargs):
            self: BaseSubprocessFlow
            temp_dir_base = get_working_temp_path()
            gpu_info = get_gpu_availability_mode(self.flow_policies.multiproc_policy)
            gpu_strategy = get_strategy(
                gpu_info,
                self.flow_policies.gpu_policy,
                default_gpu_policy,
                supported_gpu_info=supported_gpu_policy,
                current_context_info=self.dist_info,
            )
            self.supervisor.tf_safe = False
            if gpu_info.gpu_availability is GPUAvailabilityMode.NOT_IN_USE:
                keyboard_int = signal.getsignal(signal.SIGINT)

                def child_wrapper(self, conn, temp_dir_base, *args, **kwargs):
                    # Re-install the parent's original SIGINT handler in the child.
                    signal.signal(signal.SIGINT, keyboard_int)
                    temp_dir = None
                    self.supervisor.tf_safe = False
                    try:
                        with gpu_distributed_context(gpu_strategy) as dist_info:
                            self.dist_info = dist_info
                            self.reset_subprocess()
                            func(self, *args, **kwargs)
                            temp_dir = tempfile.mkdtemp(dir=temp_dir_base)
                            memento = self.save_state(temp_dir)
                            msg = Message()
                            msg.memento = memento
                            msg.graph = self.supervisor.dump()
                            # Connection.send pickles synchronously, so a pickling failure is
                            # raised here and reported below. A multiprocessing.Queue would
                            # lose it in its background feeder thread, and it can deadlock
                            # against join() for large payloads.
                            conn.send(msg)
                    except Exception as inner_err:  # pylint: disable=broad-except
                        if temp_dir is not None:
                            shutil.rmtree(temp_dir, ignore_errors=True)
                        exception_dict = {"exception": inner_err, "traceback": traceback.format_exc()}
                        try:
                            conn.send(exception_dict)  # Send structured exception data
                        except Exception:  # pylint: disable=broad-except
                            # The exception object itself could not be pickled; send a picklable
                            # stand-in so the parent still receives the child's traceback.
                            exception_dict["exception"] = RuntimeError(f"{type(inner_err).__name__}: {inner_err}")
                            conn.send(exception_dict)
                    finally:
                        # Sentinel: tells the parent the child has nothing more to send.
                        conn.send(None)

                parent_conn, child_conn = multiprocessing.Pipe()
                proc = multiprocessing.Process(
                    target=child_wrapper,
                    args=(self, child_conn, temp_dir_base, *args),
                    kwargs=kwargs,
                )
                proc.start()
                # The child holds its own copy of this end. Closing the parent's copy lets the
                # pipe report EOF once the child exits, instead of staying open indefinitely.
                child_conn.close()
                # The parent ignores SIGINT while waiting; the child re-installed the original handler.
                signal.signal(signal.SIGINT, signal.SIG_IGN)
                result: Optional[Message] = None
                child_messages = []
                try:
                    while proc.is_alive() or parent_conn.poll(0):
                        if not parent_conn.poll(timeout=5):
                            continue
                        try:
                            msg = parent_conn.recv()
                        except EOFError:
                            break  # child closed the pipe without sending the sentinel
                        if msg is None:
                            break
                        if isinstance(msg, Message):
                            result = msg
                        else:
                            child_messages.append(msg)
                    proc.join()
                finally:
                    # Restore the handler and release the pipe even if waiting on the child failed.
                    signal.signal(signal.SIGINT, keyboard_int)
                    parent_conn.close()

                if not child_messages and result is not None and proc.exitcode == 0:
                    self.supervisor.tf_safe = True
                    self.supervisor.load_state(result.graph)
                    self.load_state(result.memento, tf_safe=True)
                elif child_messages and proc.exitcode == 0:
                    raise SubprocessTracebackFailure(*child_messages)
                else:
                    raise SubprocessUnexpectedFailure(
                        f"Subprocess {func.__name__} failed with unexpected error. exitcode {proc.exitcode}",
                    )

            else:
                # Force the rebuild between Steps:
                force_rebuild = os.environ.get("HAILO_FORCE_STEP_REBUILD") == "True"
                with gpu_distributed_context(gpu_strategy) as dist_info:
                    self.supervisor.tf_safe = False
                    self.dist_info = dist_info
                    self.build_model()
                    func(self, *args, **kwargs)

                    # This flag should be present in Develop
                    if force_rebuild:
                        temp_dir = tempfile.mkdtemp(dir=temp_dir_base)
                        memento = self.save_state(temp_dir)
                        self.load_state(memento, tf_safe=True)

        return parent_wrapper

    return decorator


# Type variable for the state-snapshot type: returned by BaseSubprocessFlow.save_state() and
# Supervisor.dump(), and accepted by the matching load_state() methods.
Memento = TypeVar("Memento")


class Supervisor:
    """Minimal supervisor interface whose state travels alongside a flow's state.

    ``tf_safe`` is a class-level flag that defaults to ``False``. ``dump`` and ``load_state``
    are placeholders in this base class and both return ``None``.
    """

    tf_safe = False

    def dump(self) -> Memento:
        """Return a snapshot of the supervised state (``None`` in this base class)."""
        return None

    def load_state(self, memento: Memento):
        """Restore the supervised state from ``memento`` (no-op in this base class)."""
        return None


class SupProcessPolicies(BaseModel):
    """Policies that control how subprocess-capable flows are scheduled.

    Add new fields here for further policies that should influence subprocess tasks.
    """

    # Forwarded to get_gpu_availability_mode(); its result decides whether a subprocess is used.
    multiproc_policy: ThreeWayPolicy
    # Forwarded to get_strategy() to choose the GPU distribution strategy.
    gpu_policy: DistributionStrategy


class BaseSubprocessFlow(ABC, Generic[Memento]):
    """
    Abstract base for computational flows whose steps may run in a separate process.

    State must survive the process boundary: a step that ran elsewhere persists its state
    with ``save_state`` and the owning process restores it with ``load_state``.

    Attributes:
        dist_info (DistContextInfo): Context information about the current process distribution.
        call_history (Optional[BaseModel]): Defaults to ``None``; not used by this base class.
        temp_state: Defaults to ``None``; not used by this base class.
        supervisor (Supervisor): Supervisor object associated with the flow.

    Note:
        This class assumes that when a subprocess is forked, the flow object keeps operating
        inside the fork. After the subprocess is joined, the flow state is reloaded so that it
        reflects any changes made while the subprocess ran.
    """

    dist_info: DistContextInfo = DistContextInfo()
    call_history: Optional[BaseModel] = None
    temp_state = None
    # Class-level default: a single Supervisor instance shared by every flow instance
    # that does not assign its own ``supervisor``.
    supervisor = Supervisor()

    @property
    def flow_policies(self) -> SupProcessPolicies:
        """Default subprocess policies: multiprocessing ``allowed``, GPU distribution ``AUTO``.

        Returns:
            SupProcessPolicies: The multiprocessing and GPU policies for this flow.
        """
        return SupProcessPolicies(
            multiproc_policy=ThreeWayPolicy.allowed,
            gpu_policy=DistributionStrategy.AUTO,
        )

    @abstractmethod
    def reset_subprocess(self):
        """Reset the flow so a subprocess run starts from a clean state. Subclasses must implement this."""

    @abstractmethod
    def save_state(self, path) -> Memento:
        """Persist the current flow state inside a directory.

        Args:
            path (str): Directory in which the state should be saved.

        Returns:
            Memento: An object describing the saved state, suitable for ``load_state``.
        """

    @abstractmethod
    def load_state(self, memento: Memento, tf_safe: bool = False):
        """Restore the flow state from a memento.

        Args:
            memento (Memento): Object describing a previously saved state.
            tf_safe (bool): Whether loading must take TF-specific precautions into account.
        """

    @abstractmethod
    def build_model(self, force=False):
        """Build, or rebuild, the model associated with the flow.

        Args:
            force (bool): If True, rebuild the model even when one already exists.
        """

    def run(self, *, memento: Optional[BaseModel] = None, run_until: Optional[str] = ""):
        """Run the flow. The base implementation does nothing and returns ``None``."""

    def stop(self):
        """Reserved hook for stopping a run; a no-op in the base class."""
