#!/usr/bin/env python

"""Represents how Hailo models are returned to the user for Tensorflow integration."""

from enum import IntEnum

from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN


class InputTensorException(Exception):
    """
    Signals that the inputs actually present do not match the inputs that were expected.

    For instance, it is raised when a single-input accessor is used on a model that has several
    inputs.
    """


class ExportLevel(IntEnum):
    """
    Enum that describes the granularity of a model export -- which layers' tensors are
    included.

    Members are used as the keys of a :class:`HailoGraphExport`'s output tensors exports
    dictionary.
    """

    #: Exports a list of tensors from the output layers only.
    OUTPUT_LAYERS = 0

    #: Exports a list of tensors from output layers only, plus their rescale operations.
    OUTPUT_LAYERS_RESCALED = 1

    #: Exports a list of tensors from all layers of the model -- output layers and inner layers.
    ALL_LAYERS = 2

    #: Exports a list of tensors at all_layers level, plus their rescale operations.
    ALL_LAYERS_RESCALED = 3

    #: Exports a list of tensors before the activation of each layer.
    ALL_LAYERS_PRE_ACT_OPS = 4

    #: Exports a list of tensors at all_layers level, plus all inner operations including bias and
    #: pre-activation.
    ALL_LAYERS_ALL_OPS = 5

    #: Exports a list of all calibration statistics tensors from all layers of the model.
    CALIBRATION_STATS = 6

    #: Exports a list of histogram tensors, used in activation clipping prior to quantization.
    ACTIVATIONS_HISTOGRAMS = 7

    # Fine-tune (FT) auxiliary tensors. HailoGraphExport exposes each of the four levels below
    # as a dictionary keyed by layer name.

    #: Exports the kernel range tensors, per layer.
    FT_KERNEL_RANGE = 8

    #: Exports the alpha-blend coefficient tensors, per layer.
    FT_ALPHA = 9

    #: Exports the eventual kernels as used in ``SdkFineTune`` mode, per layer.
    FT_FINAL_KERNEL = 10

    #: Exports the fractional part of the kernels as used in ``SdkFineTune`` mode, per layer.
    FT_KERNEL_FRAC_PART = 11

    #: Exports the fine-tune training output tensors (see ``HailoGraphExport.ft_train_output_tensors``).
    FT_TRAIN_OUTPUTS = 12


class VariableExportLevel(IntEnum):
    """
    Enum that describes which group of Tensorflow variables a variables export holds.

    Members are used as the keys of a :class:`HailoGraphExport`'s variables exports dictionary.
    """

    #: Exports a list of biases from all layers of the model.
    BIASES = 0

    #: Exports a list of fine-tuned biases used for BFT algorithm.
    BIASES_DELTA = 1

    #: Exports a list of kernel variables from all layers of the model.
    KERNELS = 2

    #: Exports a list of fine-tuned kernel variables from all layers of the model.
    KERNELS_DELTA = 3

    #: Exports a list of newly created variables that need to be initialized when the model is returned to client.
    UNSET_VARIABLES = 4


class GraphExport:
    """
    Small read-only container for a graph export, so existing export logic can be reused.

    Stores the three values given to the constructor and exposes each one through a read-only
    property of the same name.
    """

    def __init__(self, network_groups, hef_infer_wrapper, output_types):
        (
            self._network_groups,
            self._hef_infer_wrapper,
            self._output_types,
        ) = (network_groups, hef_infer_wrapper, output_types)

    network_groups = property(
        lambda self: self._network_groups,
        doc="Network groups passed to the constructor.",
    )

    hef_infer_wrapper = property(
        lambda self: self._hef_infer_wrapper,
        doc="HEF inference wrapper passed to the constructor.",
    )

    output_types = property(
        lambda self: self._output_types,
        doc="Output types passed to the constructor.",
    )


class HailoGraphExport:
    """
    Hailo model export object.

    Holds the Tensorflow session and graph of a Hailo model together with its exported tensors
    (keyed by :class:`ExportLevel`) and variables (keyed by :class:`VariableExportLevel`). It can
    also hold compiled-model artifacts: HEF data, network groups and a HEF inference wrapper.
    """

    def __init__(
        self,
        session,
        graph,
        input_tensors,
        init_output_exports=None,
        init_variables_exports=None,
        hef=None,
        network_groups=None,
        hef_infer_wrapper=None,
        load_params_func=None,
    ):
        """
        Constructor for Hailo graph export.

        Args:
            session (:obj:`tf.Session`): Tensorflow session of the returned graph.
            graph (:obj:`tf.Graph`): Tensorflow graph containing the model.
            input_tensors (dict): A dictionary mapping the model's input layers' names in the HN to
                the names of their input tensors.
            init_output_exports (dict): A dictionary of exports, where the keys are of type
                :class:`ExportLevel` and the values are exports of type
                :class:`OutputTensorsExport`.
            init_variables_exports (dict): A dictionary of exports, where the keys are of type
                :class:`VariableExportLevel` and the values are exports of type
                :class:`VariablesExport`.
            hef (bytes): HEF file data.
            network_groups (list): A list of network groups returned from target.configure.
            hef_infer_wrapper: HEF inference wrapper, exposed unchanged through
                :attr:`hef_infer_wrapper`.
            load_params_func (callable): Optional callback invoked as
                ``load_params_func(session, graph)`` every time :attr:`session` is accessed.
                ``None`` disables parameter loading.

        Note:
            Non-empty ``init_output_exports`` / ``init_variables_exports`` dictionaries are stored
            by reference, not copied. Later ``add_*_export`` calls therefore also mutate the
            caller's dictionary.
        """
        self._session = session
        self._graph = graph
        self._input_tensors = input_tensors
        self._rescale_output = False
        self._load_params_func = load_params_func

        # A None or empty initial dictionary falls back to a fresh private one.
        self._output_tensors_exports = init_output_exports if init_output_exports else {}
        self._variables_exports = init_variables_exports if init_variables_exports else {}

        self._hef = hef
        self._network_groups = network_groups
        self._hef_infer_wrapper = hef_infer_wrapper

    @property
    def hef(self):
        """bytes: HEF binary compiled model data that is loaded to the device (``None`` if not given)."""
        return self._hef

    @property
    def input_tensor(self):
        """
        :obj:`tf.Tensor`: The input tensor of Hailo emulator/hardware graph.

        Only valid for single-input models.

        Raises:
            InputTensorException: If the model has more than one input, or no inputs at all.
        """
        # TODO: remove this method after multiple inputs integration
        if len(self._input_tensors) > 1:
            raise InputTensorException("Unsupported input_tensor property for a network with multiple inputs")
        if not self._input_tensors:
            # Without this guard, ``next(iter(...))`` below would leak a bare StopIteration.
            raise InputTensorException("Unsupported input_tensor property for a network without inputs")
        return self._input_tensors[next(iter(self._input_tensors))]

    @property
    def input_tensors(self):
        """dict: The model inputs, keyed by HN input layer name, as passed to the constructor."""
        return self._input_tensors

    @input_tensors.setter
    def input_tensors(self, value):
        self._input_tensors = value

    @property
    def network_groups(self):
        """list: Network groups returned from ``target.configure`` (``None`` if not given)."""
        return self._network_groups

    @property
    def hef_infer_wrapper(self):
        """The HEF inference wrapper passed to the constructor (``None`` if not given)."""
        return self._hef_infer_wrapper

    @property
    def output_export(self):
        """
        :class:`OutputTensorsExport`: The ``OUTPUT_LAYERS_RESCALED`` export when
        ``rescale_output`` is enabled, otherwise the ``OUTPUT_LAYERS`` export. Raises ``KeyError``
        if that level was not exported.
        """
        if self.rescale_output:
            return self._output_tensors_exports[ExportLevel.OUTPUT_LAYERS_RESCALED]
        return self._output_tensors_exports[ExportLevel.OUTPUT_LAYERS]

    @property
    def output_tensors(self):
        """
        list of :obj:`tf.Tensor`: Output tensors export at the ``OUTPUT_LAYERS`` export level. If
        ``rescale_output`` is enabled, this function returns the rescaled version of the same tensor
        list.
        """
        return self.output_export.tensors

    @property
    def ft_train_output_tensors(self):
        """List of :obj:`tf.Tensor`: Tensors of the ``FT_TRAIN_OUTPUTS`` export level."""
        return self._output_tensors_exports[ExportLevel.FT_TRAIN_OUTPUTS].tensors

    @property
    def all_layers(self):
        """
        list of :obj:`tf.Tensor`: All layers (outputs and inner layers) tensors list. If
        ``rescale_output`` is enabled, this function returns the rescaled version of the same tensor
        list.
        """
        if self.rescale_output:
            return self._output_tensors_exports[ExportLevel.ALL_LAYERS_RESCALED].tensors
        return self._output_tensors_exports[ExportLevel.ALL_LAYERS].tensors

    @property
    def all_layers_pre_act_ops(self):
        """
        list of :obj:`tf.Tensor`: Full graph tensors list -- including all inner ops of each
        layer prior to activation.
        """
        return self._output_tensors_exports[ExportLevel.ALL_LAYERS_PRE_ACT_OPS].tensors

    @property
    def all_layers_all_ops(self):
        """
        list of :obj:`tf.Tensor`: Full graph tensors list -- including all inner ops of each
        layer.
        """
        return self._output_tensors_exports[ExportLevel.ALL_LAYERS_ALL_OPS].tensors

    @property
    def calibration_stats(self):
        """List of :obj:`tf.Tensor`: Full graph stats tensors list -- used for model calibration."""
        return self._output_tensors_exports[ExportLevel.CALIBRATION_STATS].tensors

    @property
    def activations_histograms(self):
        """
        list of :obj:`tf.Tensor`: Activations histograms tensors list -- used for model
        calibration.
        """
        return self._output_tensors_exports[ExportLevel.ACTIVATIONS_HISTOGRAMS].tensors

    @property
    def activations_histograms_layers_names(self):
        """List of str: Corresponding layers' names list of all activations histograms tensors."""
        return self._output_tensors_exports[ExportLevel.ACTIVATIONS_HISTOGRAMS].layers_names

    @property
    def biases(self):
        """List of :obj:`tf.Variable`: Full graph bias variables list."""
        return self._variables_exports[VariableExportLevel.BIASES].variables

    @property
    def biases_layers_names(self):
        """
        list of str: Corresponding layers' names list of all bias variables (output and inner
        layers).
        """
        return self._variables_exports[VariableExportLevel.BIASES].layers_names

    @property
    def biases_delta(self):
        """
        list of :obj:`tf.Variable`: Full graph fine tune bias variables list -- used for BFT
        algorithm.
        """
        return self._variables_exports[VariableExportLevel.BIASES_DELTA].variables

    @property
    def biases_delta_layers_names(self):
        """
        list of str: Corresponding layers' names list of all BFT fine tune bias variables (output
        and inner layers).
        """
        return self._variables_exports[VariableExportLevel.BIASES_DELTA].layers_names

    @property
    def kernels(self):
        """List of :obj:`tf.Variable`: Full graph kernel variables list."""
        return self._variables_exports[VariableExportLevel.KERNELS].variables

    @property
    def kernels_layers_names(self):
        """
        list of str: Corresponding layers' names list of kernel variables list (output and inner
        layers).
        """
        return self._variables_exports[VariableExportLevel.KERNELS].layers_names

    @property
    def kernels_delta(self):
        """List of :obj:`tf.Variable`: Full graph fine-tuned kernel variables list."""
        return self._variables_exports[VariableExportLevel.KERNELS_DELTA].variables

    @property
    def kernels_delta_layers_names(self):
        """
        list of str: Corresponding layers' names list of fine-tuned kernel variables list (output
        and inner layers).
        """
        return self._variables_exports[VariableExportLevel.KERNELS_DELTA].layers_names

    def _tensors_by_layer_name(self, export_level):
        """Map each layer name of the ``export_level`` output export to its tensor (``KeyError`` if absent)."""
        export = self._output_tensors_exports[export_level]
        return dict(zip(export.layers_names, export.tensors))

    @property
    def ft_kernel_range_tensors(self):
        """Dict of :obj:`tf.Tensor`: kernel range by layer name."""
        return self._tensors_by_layer_name(ExportLevel.FT_KERNEL_RANGE)

    @property
    def ft_alpha_tensors(self):
        """Dict of :obj:`tf.Tensor`: alpha-blend coefficient by layer name."""
        return self._tensors_by_layer_name(ExportLevel.FT_ALPHA)

    @property
    def ft_final_kernel_tensors(self):
        """
        dict of :obj:`tf.Tensor`: eventual kernels as used in
        :class:`SdkFineTune <hailo_sdk_common.targets.inference_targets.SdkFineTune>` mode by layer
        name.
        """
        return self._tensors_by_layer_name(ExportLevel.FT_FINAL_KERNEL)

    @property
    def ft_kern_frac_part_tensors(self):
        """
        dict of :obj:`tf.Tensor`: fractional part of kernels as used in
        :class:`SdkFineTune <hailo_sdk_common.targets.inference_targets.SdkFineTune>` mode by layer
        name.
        """
        return self._tensors_by_layer_name(ExportLevel.FT_KERNEL_FRAC_PART)

    @property
    def unset_variables(self):
        """List of :obj:`tf.Variable`: un-initialized variables list."""
        return self._variables_exports[VariableExportLevel.UNSET_VARIABLES].variables

    @property
    def unset_variables_layers_names(self):
        """List of str: un-initialized variables list of layers' names."""
        return self._variables_exports[VariableExportLevel.UNSET_VARIABLES].layers_names

    @property
    def output_tensors_original_names(self):
        """
        list of list of str: Corresponding original layers' names list of the basic output tensors
        list.
        """
        return self._output_tensors_exports[ExportLevel.OUTPUT_LAYERS].original_names

    @property
    def all_layers_original_names(self):
        """
        list of list of str: Corresponding original layers' names list of all layers tensors list
        (output and inner layers).
        """
        return self._output_tensors_exports[ExportLevel.ALL_LAYERS].original_names

    @property
    def output_tensors_layers_names(self):
        """List of str: Corresponding layers' names list of the basic output tensors list."""
        return self._output_tensors_exports[ExportLevel.OUTPUT_LAYERS].layers_names

    @property
    def all_layers_names(self):
        """
        list of str: Corresponding layers' names list of all layers tensors list (output and inner
        layers).
        """
        return self._output_tensors_exports[ExportLevel.ALL_LAYERS].layers_names

    @property
    def graph(self):
        """:obj:`tf.Graph`: Tensorflow graph to which the new nodes were appended."""
        return self._graph

    @property
    def rescale_output(self):
        """
        bool: A flag for ``rescale_output``. If enabled,
        :func:`output_tensors <HailoGraphExport.output_tensors>` and
        :func:`all_layers <HailoGraphExport.all_layers>` properties will return the rescaled
        versions of their levels' tensors lists, respectively.
        """
        return self._rescale_output

    @rescale_output.setter
    def rescale_output(self, should_rescale):
        self._rescale_output = should_rescale

    @property
    def output_tensors_exports(self):
        """
        dict: A dictionary of exports where the keys are of type :class:`ExportLevel` and the
        values are exports of type :class:`OutputTensorsExport`. Each export holds a list of
        output Tensorflow tensors, and a list of Hailo layer names corresponding to each output
        tensor.
        """
        return self._output_tensors_exports

    @output_tensors_exports.setter
    def output_tensors_exports(self, new_output_exports_dict):
        self._output_tensors_exports = new_output_exports_dict

    @property
    def variables_exports(self):
        """
        dict: A dictionary of exports where the keys are of type :class:`VariableExportLevel` and the
        values are exports of type :class:`VariablesExport`. Each export holds a list of
        Tensorflow variables, and a list of Hailo layer names corresponding to each
        variable.
        """
        return self._variables_exports

    def _load_params(self):
        """Invoke ``load_params_func(session, graph)`` if one was provided; return its result or None."""
        if self._load_params_func is None:
            return None
        return self._load_params_func(self._session, self._graph)

    @property
    def session(self):
        """
        :obj:`tf.Session`: The Tensorflow session, returned after loading parameters.

        Each access first calls ``load_params_func(session, graph)`` when one was given to the
        constructor, so reading this property has side effects. Without a callback, the stored
        session is returned unchanged.
        """
        self._load_params()
        return self._session

    def get_export_by_level(self, export_level=ExportLevel.OUTPUT_LAYERS):
        """
        Retrieve an export entry from the dictionary, according to the given export level.

        Args:
            export_level (:class:`ExportLevel`): Which export to get from the dictionary.

        Returns:
            :class:`OutputTensorsExport`: Selected export, or None if the level isn't in the
            dictionary.

        """
        return self._output_tensors_exports.get(export_level)

    def get_variable_export_by_level(self, export_level=VariableExportLevel.BIASES):
        """
        Retrieve a variable export entry from the dictionary, according to the given export level.

        Args:
            export_level (:class:`VariableExportLevel`): Which variable export to get from the dictionary.

        Returns:
            :class:`VariablesExport`: Selected export, or None if the level isn't in the dictionary.

        """
        return self._variables_exports.get(export_level)

    def get_variable_export_by_layer(self, export_level=VariableExportLevel.BIASES):
        """
        Retrieve a variable export entry from the dictionary, according to the given export level.

        Kept for backward compatibility. The name is misleading because the lookup is by export
        level, not by layer. Prefer :meth:`get_variable_export_by_level`.

        Args:
            export_level (:class:`VariableExportLevel`): Which variable export to get from the dictionary.

        Returns:
            :class:`VariablesExport`: Selected export, or None if the level isn't in the dictionary.

        """
        return self.get_variable_export_by_level(export_level)

    def get_layers_names_by_level(self, export_level=ExportLevel.OUTPUT_LAYERS):
        """
        Retrieve a list of layers' names from an export in the dictionary, according to the given
        export level.

        Args:
            export_level (:class:`ExportLevel`): Which export to get from the dictionary.

        Returns:
            list of str: Selected export's layer names list, or None if the level isn't in the
            dictionary.

        """
        export = self._output_tensors_exports.get(export_level)
        return None if export is None else export.layers_names

    def get_variable_layers_names_by_level(self, export_level=VariableExportLevel.BIASES):
        """
        Retrieves a list of variable layer names from an export in the dictionary, according to
        the given export level.

        Args:
            export_level (:class:`VariableExportLevel`): Which variable export to get from the dictionary.

        Returns:
            list of str: Selected export layer names list, or None if the level isn't in the
            dictionary.

        """
        export = self._variables_exports.get(export_level)
        return None if export is None else export.layers_names

    def get_original_names_by_level(self, export_level=ExportLevel.OUTPUT_LAYERS):
        """
        Retrieve a list of tensors' original names from an export in the dictionary, according to
        the given export level.

        Args:
            export_level (:class:`ExportLevel`): Which export to get from the dictionary.

        Returns:
            list of list of str: Selected export's original names list, or None if the level isn't
            in the dictionary.

        """
        export = self._output_tensors_exports.get(export_level)
        return None if export is None else export.original_names

    def get_variable_original_names_by_level(self, export_level=VariableExportLevel.BIASES):
        """
        Retrieve a list of variables original names from an export in the dictionary, according to the given
        export level.

        Args:
            export_level (:class:`VariableExportLevel`): Which variable export to get from the dictionary.

        Returns:
            list of list of str: Selected variable export original names list, or None if the level isn't in the
            dictionary.

        """
        export = self._variables_exports.get(export_level)
        return None if export is None else export.original_names

    def add_output_tensors_export(self, new_export):
        """
        Adds an export output_tensor entry to the dictionary, according to the given export level.

        An existing entry at the same level is replaced.

        Args:
            new_export (:class:`OutputTensorsExport`): The new export to be added.

        """
        self._output_tensors_exports[new_export.export_level] = new_export

    def add_variables_export(self, new_export):
        """
        Adds an export variable entry to the dictionary, according to the given export level.

        An existing entry at the same level is replaced.

        Args:
            new_export (:class:`VariablesExport`): The new export to be added.

        """
        self._variables_exports[new_export.export_level] = new_export

    def update_original_names(self, hailo_nn):
        """
        Updates original names lists for all export entries according to their layer names lists
        and the given HN.

        Output tensor exports are processed first, then variable exports. Exports whose
        ``layers_names`` is empty or None are left untouched.

        Args:
            hailo_nn: Hailo NN description, passed as-is to ``HailoNN.from_parsed_hn`` to build an
                HN object for looking up layers by name. NOTE(review): this argument was previously
                documented as a JSON string, but ``from_parsed_hn`` suggests an already-parsed
                object. Confirm the expected type.

        """
        hn = HailoNN.from_parsed_hn(hailo_nn)
        for exports in (self._output_tensors_exports, self._variables_exports):
            for export in exports.values():
                if export.layers_names:
                    export.original_names = [hn.get_layer_by_name(name).original_names for name in export.layers_names]


class OutputTensorsExport:
    """
    A single output tensors export as part of the :class:`HailoGraphExport` object which holds
    one of this for each export level.
    """

    def __init__(self, export_level, tensors, layers_names):
        """
        Constructor for a single export. Initializes ``original_names`` as an empty list, it's
        calculated post serialization to client.

        Args:
            export_level (:class:`ExportLevel`): Export level to which the tensors belong.
            tensors (list of :obj:`tf.Tensor`): List of tensors, appended to the graph by the SDK.
            layers_names (list of str): List of layers' names, with respective layers' names for each
                tensor in ``tensors``.

        """
        self._export_level = export_level
        self._tensors = tensors
        self._layers_names = layers_names
        self._original_names = []

    @property
    def export_level(self):
        """:class:`ExportLevel`: Export level to which the tensors belong."""
        return self._export_level

    @export_level.setter
    def export_level(self, level):
        self._export_level = level

    @property
    def tensors(self):
        """List of :obj:`tf.Tensor`: List of tensors appended to the graph by the SDK."""
        return self._tensors

    @property
    def layers_names(self):
        """List of str: List of layers' names, with respective layers' names for each tensor."""
        return self._layers_names

    @property
    def original_names(self):
        """
        list of list of str: List of layers' names in the original user's model, with respective
        names for each tensor.
        """
        return self._original_names

    @original_names.setter
    def original_names(self, original_names):
        self._original_names = original_names

    def __repr__(self):
        return f"{type(self).__name__}(export_level={self._export_level!r}, layers_names={self._layers_names!r})"

    def __add__(self, other):
        """
        Concatenate two exports into a new one.

        The result keeps the common export level, or ``None`` when the two levels differ. It
        concatenates ``tensors`` and ``layers_names``. ``original_names`` are concatenated only
        when each operand's ``original_names`` is aligned (same length) with its
        ``layers_names``. Otherwise the result starts with an empty list, like a freshly
        constructed export.

        Returns:
            :class:`OutputTensorsExport`, or ``NotImplemented`` when ``other`` is not an
            :class:`OutputTensorsExport`. Python then raises ``TypeError``.
        """
        if not isinstance(other, OutputTensorsExport):
            return NotImplemented
        combined = OutputTensorsExport(
            self.export_level if self.export_level == other.export_level else None,
            self.tensors + other.tensors,
            self.layers_names + other.layers_names,
        )
        # Propagate original names only when both sides are aligned with their layer names,
        # so the combined list stays index-aligned with combined.layers_names.
        self_aligned = len(self.original_names) == len(self.layers_names)
        other_aligned = len(other.original_names) == len(other.layers_names)
        if self_aligned and other_aligned:
            combined.original_names = self.original_names + other.original_names
        return combined


class VariablesExport:
    """
    A single variables export as part of the :class:`HailoGraphExport` object which holds one of
    this for each export level.
    """

    def __init__(self, export_level, variables, layers_names):
        """
        Constructor for a single export. Initializes ``original_names`` as an empty list, it's
        calculated post serialization to client.

        Args:
            export_level (:class:`VariableExportLevel`): Export level to which the variables belong.
            variables (list of :obj:`tf.Variable`): List of variables, appended to the graph by the
                SDK.
            layers_names (list of str): List of layers' names, with respective layers' names for each
                variable in ``variables``.

        """
        self._export_level = export_level
        self._variables = variables
        self._layers_names = layers_names
        self._original_names = []

    @property
    def export_level(self):
        """:class:`VariableExportLevel`: Export level to which the variables belong."""
        return self._export_level

    @export_level.setter
    def export_level(self, level):
        self._export_level = level

    @property
    def variables(self):
        """List of :obj:`tf.Variable`: List of variables appended to the graph by the SDK."""
        return self._variables

    @property
    def layers_names(self):
        """List of str: List of layers' names, with respective layers' names for each variable."""
        return self._layers_names

    @property
    def original_names(self):
        """
        list of list of str: List of layers' names in the original user's model, with respective
        names for each variable.
        """
        return self._original_names

    @original_names.setter
    def original_names(self, original_names):
        self._original_names = original_names

    def __repr__(self):
        return f"{type(self).__name__}(export_level={self._export_level!r}, layers_names={self._layers_names!r})"
