import contextlib
import os
import sys
from multiprocessing import Pipe, Process, Queue

import tensorflow as tf

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    GPUAvailabilityMode,
    GPUInfo,
    ThreeWayPolicy,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import SubprocessFailure
from hailo_model_optimization.acceleras.utils.dataset_util import data_to_dataset, get_dataset_length


def get_gpu_availability_mode(multiproc_policy="allowed") -> GPUInfo:
    """Detect GPU availability and whether the GPU may be used from a subprocess.

    Normally the probe runs in a child process. Touching the GPU can crash the
    process that does it, so the caller's process is never used for the probe.
    When the environment variable ``HAILO_DISABLE_MO_SUB_PROCESS`` is ``"1"``, the
    probe runs in-process instead. In that case a ``NOT_IN_USE`` result is reported
    as ``IN_USE`` so callers do not start further subprocesses, and
    ``multiproc_policy`` is not consulted.

    Args:
        multiproc_policy: Value convertible to ``ThreeWayPolicy``.
            ``allowed`` returns the probe result unchanged.
            ``disabled`` downgrades ``NOT_IN_USE`` to ``IN_USE``.
            ``enabled`` requires the probe result to be ``NOT_IN_USE``.

    Returns:
        GPUInfo holding the availability mode and the number of physical GPUs.

    Raises:
        ValueError: If ``multiproc_policy`` is ``enabled`` and the GPU is not
            ``NOT_IN_USE``, or if the policy value is not recognized.
            The ``ThreeWayPolicy(...)`` lookup may also raise for unknown values.
        SubprocessFailure: If the probe subprocess exits with a non-zero code or
            exits without reporting a result.
    """
    from queue import Empty

    # Background on why GPU availability is probed this way:
    # - tf.config.list_physical_devices("GPU") may list GPUs that cannot actually be used.
    # - tf.test.is_gpu_available() worked but is deprecated (TF 2.18).
    # - device_lib (tensorflow.python.client) is an internal, undocumented API, see
    #   https://stackoverflow.com/questions/38559755/how-to-get-current-available-gpus-in-tensorflow
    # - Creating a tensor on the GPU works as a probe, but if the GPU is already in use the
    #   process aborts (SIGABRT, exitcode == -signal.SIGABRT). That is acceptable in a child
    #   process and unacceptable in the parent. Replacing the SIGABRT handler
    #   (signal.signal(signal.SIGABRT, lambda: None)) avoids the stack print.
    # - The probe prints TensorFlow's CUDA initialization output, which cannot be fully suppressed.
    # - tf.config is the documented API and works as long as the user does not fork the SDK themselves.
    result_timeout_sec = 10  # safety net only; a cleanly exited child has already flushed its result

    def is_gpu_usable():
        gpus = tf.config.list_physical_devices("GPU")
        if not gpus:
            return False
        try:
            # Fails if the GPU context is unusable.
            tf.config.experimental.set_memory_growth(gpus[0], True)
            # A trivial allocation on the first GPU.
            with tf.device("/GPU:0"):
                tf.constant(1.0)
            return True
        except Exception:
            # Deliberate best effort: any failure means the GPU is not usable from here.
            return False

    def get_gpu_info():
        num_gpus = len(tf.config.list_physical_devices("GPU"))
        if num_gpus == 0:
            availability = GPUAvailabilityMode.NOT_AVAILABLE
        elif is_gpu_usable():
            # The GPU exists and could be used here, so nothing else holds it.
            availability = GPUAvailabilityMode.NOT_IN_USE
        else:
            availability = GPUAvailabilityMode.IN_USE
        return GPUInfo(gpu_availability=availability, num_gpus=num_gpus)

    if os.environ.get("HAILO_DISABLE_MO_SUB_PROCESS", "0") == "1":
        gpu_info = get_gpu_info()
        if gpu_info.gpu_availability == GPUAvailabilityMode.NOT_IN_USE:
            # Report "in use" so that callers do not start multiprocessing.
            gpu_info.gpu_availability = GPUAvailabilityMode.IN_USE
        return gpu_info

    # Validate the policy before paying for the probe subprocess.
    multiproc_policy = ThreeWayPolicy(multiproc_policy)

    def child_proc(out_queue: Queue):
        with open(os.devnull, "w", encoding="utf-8") as devnull, contextlib.redirect_stdout(
            devnull
        ), contextlib.redirect_stderr(devnull):
            out_queue.put(get_gpu_info())
            sys.exit(0)

    result_queue = Queue()
    try:
        proc = Process(target=child_proc, args=(result_queue,))
        proc.start()
        # NOTE(review): joining before reading assumes the pickled GPUInfo fits in the
        # pipe buffer (it is expected to be tiny); a large payload would require reading first.
        proc.join()
        # Check the exit code BEFORE reading. A crashed child (e.g. SIGABRT during the GPU
        # probe) never puts a result, and an unbounded get() would block forever.
        if proc.exitcode != 0:
            raise SubprocessFailure(f"GPU availability check subprocess failed with exitcode {proc.exitcode}")
        try:
            gpu_info = result_queue.get(timeout=result_timeout_sec)
        except Empty as err:
            raise SubprocessFailure(
                "GPU availability check subprocess exited without reporting a result"
            ) from err
    finally:
        result_queue.close()

    if multiproc_policy is ThreeWayPolicy.allowed:
        pass
    elif multiproc_policy is ThreeWayPolicy.disabled:
        if gpu_info.gpu_availability is GPUAvailabilityMode.NOT_IN_USE:
            gpu_info.gpu_availability = GPUAvailabilityMode.IN_USE
    elif multiproc_policy is ThreeWayPolicy.enabled:
        if gpu_info.gpu_availability is not GPUAvailabilityMode.NOT_IN_USE:
            raise ValueError(
                f"Can't force multiprocessing: GPU availability is {gpu_info.gpu_availability}, not NOT_IN_USE"
            )
    else:
        raise ValueError(f"multiproc_policy received unexpected value {multiproc_policy}")

    return gpu_info


def get_tf_dataset_length(data, data_type, threshold, gpu_state: GPUAvailabilityMode):
    """Convert ``data`` to a dataset and return its length.

    If ``gpu_state`` is ``NOT_IN_USE``, the conversion and length computation run
    in a child process. This is presumably so that the parent process does not
    start TensorFlow work before it later forks. Otherwise they run in the
    current process:
    - If the GPU is in use, the same process must be used.
    - If no GPU is available, TensorFlow may already have run on the CPU, so the
      process can't be forked.

    Args:
        data: Input forwarded to ``data_to_dataset``.
        data_type: Data type forwarded to ``data_to_dataset``.
        threshold: Forwarded to ``get_dataset_length`` as ``threshold``.
        gpu_state: GPU availability mode that decides whether a subprocess is used.

    Returns:
        The value returned by ``get_dataset_length`` for the converted dataset.

    Raises:
        SubprocessFailure: If the child process reports an exception (its
            traceback is included), or exits without reporting a result.
    """

    def convert_and_get_length(data, data_type, threshold):
        dataset, _ = data_to_dataset(data, data_type)
        return get_dataset_length(dataset, threshold=threshold)

    def child_proc(conn):
        try:
            dataset_length = convert_and_get_length(data, data_type, threshold)
            conn.send({"dataset_length": dataset_length, "error": None})
        except Exception:
            import traceback

            conn.send({"dataset_length": None, "error": traceback.format_exc()})
        finally:
            conn.close()

    if gpu_state != GPUAvailabilityMode.NOT_IN_USE:
        # IN_USE: the GPU is busy, so stay in this process.
        # NOT_AVAILABLE: TF may already have worked on the CPU here, and can't be forked.
        return convert_and_get_length(data, data_type, threshold)

    parent_conn, child_conn = Pipe()
    proc = Process(target=child_proc, args=(child_conn,))
    proc.start()
    # Drop the parent's copy of the child's end. Otherwise, if the child dies without
    # sending (native crash, non-Exception exit), recv() never sees EOF and hangs forever.
    child_conn.close()
    try:
        # Read BEFORE join(). A payload larger than the pipe buffer (e.g. a long traceback)
        # blocks the child's send() until the parent reads, so joining first could deadlock.
        result = parent_conn.recv()
    except EOFError:
        result = None
    finally:
        parent_conn.close()
    proc.join()

    if result is None:
        raise SubprocessFailure(
            f"Dataset length check subprocess exited with exitcode {proc.exitcode} without reporting a result"
        )
    if result.get("error"):
        raise SubprocessFailure(f"Dataset length check subprocess failed with error:\n{result['error']}")
    return result["dataset_length"]
