"""This module contains enums used by several SDK APIs."""

from enum import Enum
from typing import Optional, Protocol

from hailo_model_optimization.acceleras.utils.acceleras_definitions import CalibrationDataType, DistributionStrategy
from hailo_sdk_common.hailo_nn.hn_definitions import NMSMetaArchitectures  # noqa: F401

# Hardware architecture names, grouped by their level of support in this SDK.
SUPPORTED_HW_ARCHS = [
    "hailo8",
    "hailo8r",
    "hailo8l",
]
NON_SUPPORTED_HW_ARCHS = [
    "hailo15h",
    "hailo15m",
    "hailo15l",
    "hailo10h",
    "hailo10p",
]
PARTIALLY_SUPPORTED_HW_ARCHS = [
    "pluto",
    "hailo10h2",
    "mars",
]
# Default hardware architecture name.
DEFAULT_HW_ARCH = "hailo8"


class JoinAction(Enum):
    """
    Special actions that can be performed when joining two models.

    See Also
        The :func:`~hailo_sdk_client.runner.client_runner.ClientRunner.join` API uses this enum.

    """

    #: Join the graphs without creating any connection between them.
    NONE = "none"

    #: Automatically detect the inputs of both graphs and merge them into a single input. This only works
    #: when both networks have a single input, and both inputs have the same shape.
    AUTO_JOIN_INPUTS = "auto_join_inputs"

    #: Automatically detect the output of this model and the input of the other model, and connect them.
    #: This only works when this model has a single output and the other model has a single input,
    #: and both have the same shape.
    AUTO_CHAIN_NETWORKS = "auto_chain_networks"

    #: Use a custom ``join_action_info`` dictionary that specifies which nodes of this model
    #: are connected to which nodes of the other graph. If both the keys and the values are
    #: inputs, the inputs are joined. If the keys are outputs and the values are inputs, the networks
    #: are chained as described by the dictionary.
    CUSTOM = "custom"

    def __str__(self):
        # Render as the raw value (e.g. "custom") rather than the default "JoinAction.CUSTOM".
        return str(self.value)


class JoinOutputLayersOrder(Enum):
    """Enum-like class selecting how a model's outputs are ordered after it is joined with another model."""

    #: The outputs of this model that remain outputs come first, followed by the outputs of the other model.
    #: Within each group, the original order is preserved.
    NEW_OUTPUTS_LAST = "new_outputs_last"

    #: The outputs of the other model come first, followed by the outputs of this model that remain outputs.
    #: Within each group, the original order is preserved.
    NEW_OUTPUTS_FIRST = "new_outputs_first"

    #: When the models are chained, the other model's outputs are inserted, in their original order,
    #: into this model's output list in place of the first output that is no longer an output.
    #: When the models are joined by their inputs, the other model's outputs are appended last.
    NEW_OUTPUTS_IN_PLACE = "new_outputs_in_place"


class NNFramework(Enum):
    """Enum-like class for the supported neural network frameworks."""

    #: TensorFlow Lite
    TENSORFLOW_LITE = "tflite"

    #: ONNX
    ONNX = "onnx"

    def __str__(self):
        # Map the short identifier to a human-readable framework name; any unmapped value renders as "Unknown".
        display_names = {"tflite": "TensorFlow Lite", "onnx": "ONNX"}
        return display_names.get(self.value, "Unknown")


# Lookup table from a framework's short identifier (e.g. "onnx") to its NNFramework member.
NNFrameworks = {framework.value: framework for framework in NNFramework}
# Default neural network framework.
DEFAULT_NN_FRAMEWORK = NNFramework.ONNX

# Inference data types reuse the calibration data-type enum imported from acceleras.
InferenceDataType = CalibrationDataType


class States(str, Enum):
    """
    Enum-like class listing every :class:`~hailo_sdk_client.runner.client_runner.ClientRunner` state.

    Members are also ``str`` instances, so each one compares equal to its raw value
    (e.g. ``States.HAILO_MODEL == "hailo_model"``).
    """

    #: Initial state of a newly created :class:`~hailo_sdk_client.runner.client_runner.ClientRunner`.
    UNINITIALIZED = "uninitialized"

    #: :class:`~hailo_sdk_client.runner.client_runner.ClientRunner` state once the path to the original model (ONNX/TF model) has been set.
    ORIGINAL_MODEL = "original_model"

    #: :class:`~hailo_sdk_client.runner.client_runner.ClientRunner` state after parsing, i.e. after calling the :func:`~hailo_sdk_client.runner.client_runner.ClientRunner.translate_onnx_model`/:func:`~hailo_sdk_client.runner.client_runner.ClientRunner.translate_tf_model` API.
    HAILO_MODEL = "hailo_model"

    #: :class:`~hailo_sdk_client.runner.client_runner.ClientRunner` state after calling the :func:`optimize_full_precision() <hailo_sdk_client.runner.client_runner.ClientRunner.optimize_full_precision>` API.
    #: This state includes all full-precision optimizations, such as model modification commands.
    FP_OPTIMIZED_MODEL = "fp_optimized_model"

    #: :class:`~hailo_sdk_client.runner.client_runner.ClientRunner` state after calling the :func:`optimize() <hailo_sdk_client.runner.client_runner.ClientRunner.optimize>` API.
    #: This state includes the quantized weights.
    QUANTIZED_MODEL = "quantized_model"

    # NOTE(review): the value below is spelled "quanzited" (likely a typo of "quantized"). It is kept
    # byte-identical because renaming it may break compatibility with previously saved states; confirm before fixing.
    #: :class:`~hailo_sdk_client.runner.client_runner.ClientRunner` state after calling, for example, the :func:`load_lora_weights() <hailo_sdk_client.runner.client_runner.ClientRunner.load_lora_weights>` API.
    #: This state includes layers with non-quantized weights (e.g. LoRA layers) that were added as a fine-tune on top of a quantized base.
    QUANTIZED_BASE_MODEL = "quanzited_base_model"

    #: :class:`~hailo_sdk_client.runner.client_runner.ClientRunner` state after calling the :func:`optimize() <hailo_sdk_client.runner.client_runner.ClientRunner.optimize>` API and saving in compilation-only mode.
    #: This state holds only the information needed for compilation (for example, the quantized weights but not the full-precision information).
    QUANTIZED_SLIM_MODEL = "quantized_slim_model"

    #: :class:`~hailo_sdk_client.runner.client_runner.ClientRunner` state after compilation, i.e. after calling the :func:`compile() <hailo_sdk_client.runner.client_runner.ClientRunner.compile>` API.
    COMPILED_MODEL = "compiled_model"

    #: :class:`~hailo_sdk_client.runner.client_runner.ClientRunner` state after compiling a quantized slim model with the :func:`compile() <hailo_sdk_client.runner.client_runner.ClientRunner.compile>` API.
    #: This state allows evaluation only (profiling, inference).
    COMPILED_SLIM_MODEL = "compiled_slim_model"

    def __str__(self):
        # Render as the raw state value (e.g. "hailo_model") rather than "States.HAILO_MODEL".
        return str(self.value)


class InferenceContext(Enum):
    """Enum-like class with all the possible inference context modes."""

    #: SDK_NATIVE context runs inference on the original model, without any modification.
    SDK_NATIVE = "sdk_native"

    #: SDK_FP_OPTIMIZED context includes all floating-point model modifications (such as normalization, NMS, and so on).
    SDK_FP_OPTIMIZED = "sdk_fp_optimized"

    #: SDK_QUANTIZED context runs inference on the quantized model. Used to measure the degradation caused by quantization.
    SDK_QUANTIZED = "sdk_quantized"

    #: SDK_HAILO_HW context runs inference on Hailo hardware.
    SDK_HAILO_HW = "sdk_hailo_hw"

    #: SDK_BIT_EXACT (preview) bit-exact emulation. Not all layers and modes are currently supported.
    SDK_BIT_EXACT = "sdk_bit_exact"

    def __str__(self):
        # Render as the raw context value (e.g. "sdk_native") rather than "InferenceContext.SDK_NATIVE".
        return str(self.value)


class ContextInfo(Protocol):
    """
    Protocol describing a context-info object, which holds the values needed to run inference within a context.

    Such an object is obtained by entering an inference context on the runner:

    .. code-block:: python

        with runner.infer_context(*args) as ctx:
            ...  # ``ctx`` satisfies the ContextInfo protocol

    """

    #: The InferenceContext used by the infer API.
    infer_context: InferenceContext

    #: State of the context (a flag; presumably ``True`` while the context is open — TODO confirm).
    open: bool

    #: SdkGraphExport object used internally by the SDK.
    # NOTE(review): annotated as ``None`` although it holds an SdkGraphExport; presumably kept opaque on purpose — confirm.
    graph_export: None

    #: Distribution strategy (policy) to use across GPUs.
    gpu_policy: DistributionStrategy

    #: Name of the LoRA adapter, or ``None`` when no adapter is set.
    lora_adapter_name: Optional[str] = None


class Dims(str, Enum):
    """
    String enum of tensor dimension names.

    ``HEADS`` and ``DISPARITY`` are aliases of ``GROUPS``: they share its value, so ``Dims("groups")``,
    ``Dims.HEADS`` and ``Dims.DISPARITY`` all resolve to the same member.
    """

    BATCH = "batch"
    STACK = "stack"
    CHANNELS = "channels"
    HEIGHT = "height"
    WIDTH = "width"
    GROUPS = "groups"
    # Aliases: reusing GROUPS' value makes these refer to the GROUPS member instead of creating new ones.
    HEADS = GROUPS
    DISPARITY = GROUPS

    def __str__(self):
        # Both str() and repr() render as the bare dimension name (e.g. "width").
        return self.value

    __repr__ = __str__
