#!/usr/bin/env python

import json
from enum import Enum

from contextlib2 import contextmanager

from hailo_sdk_common.exceptions.exceptions import CommonModelException


class InferenceTargetException(CommonModelException):
    """Error type for failures that involve an inference target.

    Within this module it is raised when an abstract target class is instantiated directly.
    """


class NumericTargetException(Exception):
    """Raised when an unsupported numeric target is supplied for mixed emulation."""


class InferenceTargets:
    """
    Namespace of string identifiers for every inference target supported by the Hailo SDK.

    This is a plain class holding string constants (not an :class:`enum.Enum`), so each
    attribute can be compared directly against a target name string.
    See the target classes themselves for details about each target.
    """

    # Placeholder name carried by abstract targets.
    UNINITIALIZED = "uninitialized"

    # Names that also appear in EmulationInferenceTargets.
    SDK_NATIVE = "sdk_native"
    SDK_FP_OPTIMIZED = "sdk_fp_optimized"
    SDK_NUMERIC = "sdk_numeric"
    SDK_DEBUG_PRECISE_NUMERIC = "sdk_debug_precise_numeric"
    SDK_PARTIAL_NUMERIC = "sdk_partial_numeric"
    SDK_FINE_TUNE = "sdk_fine_tune"
    SDK_MIXED = "sdk_mixed"

    # Remaining target names (no counterpart in EmulationInferenceTargets).
    HW_SIMULATION = "hw_sim"
    HW_SIMULATION_MULTI_CLUSTER = "hw_sim_mc"
    HW_SIMULATION_FULL_CHIP = "hw_sim_fc"
    FPGA = "fpga"
    UDP_CONTROLLER = "udp"
    PCIE_CONTROLLER = "pcie"
    HW_DRY = "hw_dry"
    HW_DRY_UPLOAD = "hw_dry_upload"
    UV_WORKER = "uv"
    DANNOX = "dannox"
    ONNXRT = "ONNXRT"
    RUBINATOR = "rubinator"


class InferenceDebugTargets(Enum):
    """
    Enumeration of all debugging options supported by the Hailo SDK.

    Each member's value is the option's string name.

    .. note:: Only the ``NO_DEBUG`` option is currently available for users.
    """

    NO_DEBUG = "no_debug"
    CLIENT_DEBUGGER = "client_debugger"
    TRACE = "trace"
    FULL = "full"


class EmulationInferenceTargets:
    """
    Namespace of string identifiers for the emulation inference targets supported by the Hailo SDK.

    This is a plain class holding string constants (not an :class:`enum.Enum`).
    See the target classes themselves for details about each target.
    """

    # Placeholder name carried by abstract targets.
    UNINITIALIZED = "uninitialized"

    SDK_NATIVE = "sdk_native"
    SDK_FP_OPTIMIZED = "sdk_fp_optimized"
    SDK_NUMERIC = "sdk_numeric"
    SDK_DEBUG_PRECISE_NUMERIC = "sdk_debug_precise_numeric"
    SDK_PARTIAL_NUMERIC = "sdk_partial_numeric"
    SDK_FINE_TUNE = "sdk_fine_tune"
    SDK_MIXED = "sdk_mixed"
    SDK_QUANTIZED = "sdk_quantized"


class ParamsKinds:
    """Namespace of string identifiers for the kinds of model parameters."""

    #: The model's original parameters, usually 32 bit floating point.
    NATIVE = "native"

    #: Parameters after batch normalization has been fused into the layers' weights, but
    #: before quantization. The SDK generates this kind automatically when native parameters
    #: are loaded.
    NATIVE_FUSED_BN = "native_fused_bn"

    #: Translated model parameters (quantized to 8 bit integer).
    TRANSLATED = "translated"

    #: Native parameters after model modification.
    FP_OPTIMIZED = "fp_optimized"

    #: Native parameters after model optimization.
    HAILO_OPTIMIZED = "hailo_optimized"

    #: Statistics parameters about the model.
    STATISTICS = "statistics"


class EmulationObject:
    """
    A software based inference target.

    Instances compare equal to the string stored in their class-level ``NAME`` and hash like
    that string, so a target object and its name can be used interchangeably in comparisons,
    sets and dictionary keys.

    Note:
        This class should not be used directly. Use only its inherited classes.
        Instantiating a class whose ``NAME`` is still ``InferenceTargets.UNINITIALIZED``
        raises :class:`InferenceTargetException`.

    """

    NAME = InferenceTargets.UNINITIALIZED
    IS_NUMERIC = False
    IS_HARDWARE = False
    IS_SIMULATION = False

    def __init__(self, hw_arch=None):
        """
        Inference object constructor.

        Args:
            hw_arch (str, optional): Name of the hardware architecture. Defaults to ``None``.

        """
        self._is_device_used = False
        self._hw_arch = hw_arch

    def __new__(cls, *args, **kwargs):
        """
        Create the instance, refusing to instantiate abstract targets.

        The constructor arguments are accepted here only so that subclasses may define their
        own ``__init__`` signatures; they are not forwarded.

        Raises:
            InferenceTargetException: If ``cls.NAME`` is ``InferenceTargets.UNINITIALIZED``.

        """
        if cls.NAME == InferenceTargets.UNINITIALIZED:
            raise InferenceTargetException(f"{cls.__name__} is an abstract target and cannot be used directly.")
        # object's __new__() takes no parameters besides the class itself.
        # Zero-argument super() resolves along the class MRO; the previous
        # super(type(cls), cls) walked the metaclass MRO instead and only reached
        # object.__new__ because the metaclass happened to be plain ``type``.
        return super().__new__(cls)

    def __eq__(self, other):
        """Compare ``other`` against this target's ``NAME``."""
        return other == type(self).NAME

    # TODO: Required for Python2 BW compatibility (SDK-10038)
    # This impl' comes by default in Python3
    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        """
        Hash consistently with :meth:`__eq__`.

        Defining ``__eq__`` alone makes instances unhashable in Python 3; hashing the name
        keeps equal objects (including the plain name string) on the same hash.
        """
        return hash(type(self).NAME)

    def __enter__(self):
        """
        Doesn't do anything, used for compatibility with the HW objects.
        """
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Doesn't do anything, used for compatibility with the HW objects.
        """

    @contextmanager
    def use_device(self, *args, **kwargs):
        """
        A context manager that should wrap any usage of the target.

        Marks the target as in use for the duration of the ``with`` block. Positional and
        keyword arguments are accepted and ignored.
        """
        self._is_device_used = True
        try:
            yield
        finally:
            # Always clear the flag, even when the wrapped block raises; otherwise the
            # target would stay marked as in use forever.
            self._is_device_used = False

    @property
    def name(self):
        """
        str: The name of this target. Valid values are defined by
        :class:`InferenceTargets <hailo_sdk_common.targets.inference_targets.InferenceTargets>`.
        """
        return type(self).NAME

    @property
    def is_numeric(self):
        """
        bool: Determines whether this target is working in numeric mode.
        """
        return type(self).IS_NUMERIC

    @property
    def is_hardware(self):
        """
        bool: Determines whether this target runs on a physical hardware device.
        """
        return type(self).IS_HARDWARE

    @property
    def is_simulation(self):
        """
        bool: Determines whether this target is used for hardware simulation.
        """
        return type(self).IS_SIMULATION

    def _get_json_dict(self):
        """Build a fresh dictionary with the fields that :meth:`to_json` serializes."""
        return {
            "name": self.name,
            "is_numeric": self.is_numeric,
            "is_hardware": self.is_hardware,
            "is_simulation": self.is_simulation,
        }

    def to_json(self):
        """
        Get a JSON representation of this object.

        Returns:
            str: A JSON dump.

        """
        return json.dumps(self._get_json_dict())


class SdkNative(EmulationObject):
    """
    Native emulation inference target.

    The model is run as is, without any hardware related changes. Use it to verify that your
    model has been converted properly into the Hailo representation (HN). This target is also
    useful when you extract statistics for weights quantization.
    """

    NAME = EmulationInferenceTargets.SDK_NATIVE


class SdkFPOptimized(EmulationObject):
    """
    Full-precision optimized emulation inference target.

    The model is run like :class:`SdkNative`, but with all full-precision optimizations applied
    to the params of layers that had one of them set via a quantization script.
    This target is used during the optimization and quantization process of a model, and can
    also be used to analyze the optimization.
    """

    NAME = EmulationInferenceTargets.SDK_FP_OPTIMIZED

    def __init__(self, hw_arch=None, enable_clipping=False):
        """
        SdkFPOptimized inference object constructor.

        Args:
            hw_arch (str, optional): Name of the hardware architecture, forwarded to the base
                class constructor. Defaults to ``None``.
            enable_clipping (bool, optional): Value exposed through :attr:`enable_clipping`.
                Defaults to ``False``.

        """
        super().__init__(hw_arch=hw_arch)
        self._enable_clipping = enable_clipping

    @property
    def enable_clipping(self):
        """The ``enable_clipping`` value that was passed to the constructor."""
        return self._enable_clipping


class SdkNumeric(EmulationObject):
    """
    Numeric emulation inference target.

    Choose this target to obtain results that are bit exact to the Hailo hardware output
    without running on an actual device.

    .. warning:: Runner `infer` API will deprecate this target in future versions (August 2023).
        To run emulation inference on the quantized model, please use the
        :class:`~hailo_sdk_common.targets.inference_targets.SdkQuantized` target instead.

    """

    NAME = EmulationInferenceTargets.SDK_NUMERIC
    IS_NUMERIC = True


class SdkDebugPreciseNumeric(EmulationObject):
    """
    Numeric emulation target that runs in a special debug mode.

    .. note:: This target is currently unavailable for users.
    """

    # NOTE(review): only NAME is overridden, so IS_NUMERIC keeps the inherited value;
    # confirm that this is intended for a numeric emulation target.
    NAME = EmulationInferenceTargets.SDK_DEBUG_PRECISE_NUMERIC


class SdkPartialNumeric(EmulationObject):
    """
    Fast numeric emulation target.

    The results are not hardware bit exact, but they are `hardware like`, and the target runs
    much faster because it uses some of the original Tensorflow CPU/GPU layers'
    implementations. It is useful for researching differences between the original model and
    the quantized model over large datasets without the actual Hailo hardware device.

    .. warning:: Runner `infer` API will deprecate this target in future versions (August 2023).
        To run emulation inference on the quantized model, please use the
        :class:`~hailo_sdk_common.targets.inference_targets.SdkQuantized` target instead.

    """

    NAME = EmulationInferenceTargets.SDK_PARTIAL_NUMERIC
    IS_NUMERIC = True


class SdkQuantized(EmulationObject):
    """
    Quantized model emulation target.

    The results are not hardware bit exact, but they are `hardware like`, and the target runs
    much faster because it uses some of the original Tensorflow CPU/GPU layers'
    implementations. It is useful for researching differences between the original model and
    the quantized model over large datasets without the actual Hailo hardware device.
    """

    NAME = EmulationInferenceTargets.SDK_QUANTIZED
    IS_NUMERIC = True


class SdkFineTune(EmulationObject):
    """
    Fine tuning target. You can use this mode to train (or fine tune) the model's weights and
    biases in a quantization aware manner. Fake quantization is used to allow back propagation.
    """

    NAME = EmulationInferenceTargets.SDK_FINE_TUNE

    def __init__(self, hw_arch=None):
        """
        Fine tune inference object constructor.

        Args:
            hw_arch (str, optional): Name of the hardware architecture, forwarded to the base
                class constructor. Defaults to ``None``.

        """
        super().__init__(hw_arch=hw_arch)
        self._fine_tune_params = FineTuneParams()

    @property
    def fine_tune_params(self):
        """
        :class:`FineTuneParams <hailo_sdk_common.targets.inference_targets.FineTuneParams>`: The
        current params of this inference object.
        """
        return self._fine_tune_params

    def set_fine_tune_params(
        self,
        should_quantize_weights=False,
        should_relax_weights=False,
        should_quantize_activations=False,
    ):
        """
        Set fine tune params.

        All three flags are overwritten on every call; a flag that is not passed explicitly is
        set to ``False``. Note that this differs from a freshly constructed
        :class:`FineTuneParams`, whose ``should_quantize_activations`` defaults to ``True``.

        Args:
            should_quantize_weights (bool, optional): New value for the weights quantization flag.
            should_relax_weights (bool, optional): New value for the weights relaxation flag.
            should_quantize_activations (bool, optional): New value for the activations
                quantization flag.

        See Also:
            The documentation of :class:`FineTuneParams <hailo_sdk_common.targets.inference_targets.FineTuneParams>`
            contains additional details.

        """
        self._fine_tune_params.should_quantize_weights = should_quantize_weights
        self._fine_tune_params.should_relax_weights = should_relax_weights
        self._fine_tune_params.should_quantize_activations = should_quantize_activations

    def _get_json_dict(self):
        """Extend the base dictionary with the fine tune params."""
        json_dict = super()._get_json_dict().copy()
        # The params are embedded as a JSON string (the output of FineTuneParams.to_json()),
        # not as a nested dictionary.
        json_dict["fine_tune_params"] = self._fine_tune_params.to_json()
        return json_dict


class SdkMixed(EmulationObject):
    """
    Mixed emulation target. Some layers will be emulated in `native` mode, while the rest will be
    emulated in `numeric` mode. This target is useful when researching the effect that quantization
    of specific layers has on the accuracy of the whole model.
    """

    NAME = EmulationInferenceTargets.SDK_MIXED

    def __init__(self, hw_arch=None):
        """
        SdkMixed inference object constructor.

        Args:
            hw_arch (str, optional): Name of the hardware architecture, forwarded to the base
                class constructor. Defaults to ``None``.

        """
        super().__init__(hw_arch=hw_arch)
        self._mixed_params = SdkMixedParams()

    @property
    def mixed_params(self):
        """
        :class:`SdkMixedParams <hailo_sdk_common.targets.inference_targets.SdkMixedParams>`: The
        current params of this inference object.
        """
        return self._mixed_params

    def set_mixed_params(self, numeric_target=EmulationInferenceTargets.SDK_NUMERIC):
        """
        Set mixed params.

        Args:
            numeric_target (str, optional): Name of the numeric emulation target to use. Must be
                equal to ``EmulationInferenceTargets.SDK_NUMERIC`` or
                ``EmulationInferenceTargets.SDK_PARTIAL_NUMERIC``.

        Raises:
            NumericTargetException: If ``numeric_target`` is neither of the supported values.

        See Also:
            The documentation of :class:`SdkMixedParams <hailo_sdk_common.targets.inference_targets.SdkMixedParams>`
            contains additional details.

        """
        if numeric_target not in [EmulationInferenceTargets.SDK_NUMERIC, EmulationInferenceTargets.SDK_PARTIAL_NUMERIC]:
            raise NumericTargetException("Error: SdkMixed works with either SdkNumeric or SdkPartialNumeric emulation")
        self._mixed_params.numeric_target = numeric_target

    def _get_json_dict(self):
        """Extend the base dictionary with the mixed params."""
        json_dict = super()._get_json_dict().copy()
        # The params are embedded as a JSON string (the output of SdkMixedParams.to_json()),
        # not as a nested dictionary.
        json_dict["mixed_params"] = self._mixed_params.to_json()
        return json_dict


class SdkMixedParams:
    """
    Parameters for the :class:`SdkMixed <hailo_sdk_common.targets.inference_targets.SdkMixed>`
    target.
    """

    DEFAULT_NUMERIC_TARGET = EmulationInferenceTargets.SDK_NUMERIC

    def __init__(self, numeric_target=DEFAULT_NUMERIC_TARGET):
        """
        SdkMixed params constructor.

        Args:
            numeric_target (optional): Which numeric emulation target to set for non-native
                layers. Must compare equal to ``EmulationInferenceTargets.SDK_NUMERIC`` or
                ``EmulationInferenceTargets.SDK_PARTIAL_NUMERIC``.

        Raises:
            NumericTargetException: If ``numeric_target`` is neither of the supported values.

        """
        supported_targets = (
            EmulationInferenceTargets.SDK_NUMERIC,
            EmulationInferenceTargets.SDK_PARTIAL_NUMERIC,
        )
        if numeric_target not in supported_targets:
            raise NumericTargetException("Error: SdkMixed works with either SdkNumeric or SdkPartialNumeric emulation")
        self.numeric_target = numeric_target

    @classmethod
    def from_json(cls, json_str):
        """
        Construct this class from previously exported JSON data.

        Args:
            json_str (str): The input JSON data; its ``"numeric_target"`` entry is used.

        Returns:
            :class:`SdkMixedParams <hailo_sdk_common.targets.inference_targets.SdkMixedParams>`:
            The object constructed from the JSON data.

        """
        fields = json.loads(json_str)
        target = fields["numeric_target"]
        return cls(target)

    def to_json(self):
        """
        Get a JSON representation of this object.

        Returns:
            str: A JSON dump holding the ``"numeric_target"`` entry.

        """
        payload = {"numeric_target": self.numeric_target}
        return json.dumps(payload)


class FineTuneParams:
    """
    Parameters for the :class:`SdkFineTune <hailo_sdk_common.targets.inference_targets.SdkFineTune>`
    target.
    """

    DEFAULT_QUANTIZE_WEIGHTS = False
    DEFAULT_QUANTIZE_ACTIVATIONS = True
    DEFAULT_RELAX_WEIGHTS = False

    def __init__(
        self,
        should_quantize_weights=DEFAULT_QUANTIZE_WEIGHTS,
        should_relax_weights=DEFAULT_RELAX_WEIGHTS,
        should_quantize_activations=DEFAULT_QUANTIZE_ACTIVATIONS,
    ):
        """
        Fine tune params constructor.

        Args:
            should_quantize_weights (bool, optional): Whether the weights should be quantized
                using fake quantization. Turning this on adds a new trainable variable named
                ``kernel_delta`` to the graph.
            should_relax_weights (bool, optional): EXPERIMENTAL. If True, weights fine-tuning
                uses gradual ("relaxed") quantization instead of STE/fake-quant, exporting the
                weights' "distance from grid" tensor so that the client can penalize it in the
                loss function, slowly driving weights towards grid. ``should_quantize_weights``
                should still be True to use this mode.
            should_quantize_activations (bool, optional): Whether the activations should be
                quantized using fake quantization. Turning this on adds a new trainable variable
                named ``fine_tune_bias`` to the graph.

        """
        self.should_quantize_weights = should_quantize_weights
        self.should_quantize_activations = should_quantize_activations
        self.should_relax_weights = should_relax_weights

    @classmethod
    def from_json(cls, json_str):
        """
        Construct this class from previously exported JSON data.

        Args:
            json_str (str): The input JSON data. All three flags must be present.

        Returns:
            :class:`FineTuneParams <hailo_sdk_common.targets.inference_targets.FineTuneParams>`:
            The object constructed from the JSON data.

        """
        fields = json.loads(json_str)
        # Listed in the positional order expected by the constructor.
        ordered_keys = (
            "should_quantize_weights",
            "should_relax_weights",
            "should_quantize_activations",
        )
        return cls(*[fields[key] for key in ordered_keys])

    def to_json(self):
        """
        Get a JSON representation of this object.

        Returns:
            str: A JSON dump holding the three flags.

        """
        payload = {}
        payload["should_quantize_weights"] = self.should_quantize_weights
        payload["should_relax_weights"] = self.should_relax_weights
        payload["should_quantize_activations"] = self.should_quantize_activations
        return json.dumps(payload)
