from hailo_sdk_client.tools.fuser.fuser_helper import FuserHelper
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType
from hailo_sdk_common.hailo_nn.hn_layers import FusedStandaloneActivationLayer, LayerWithActivation
from hailo_sdk_common.hailo_nn.hn_layers.layer_common import MODEL_SCRIPT_LAYER_PREFIX
from hailo_sdk_common.logger.logger import default_logger


class OutputActivationModifierException(Exception):
    """Raised by OutputActivationModifier when its inputs are invalid, e.g. a mismatched number of new
    activation layer names."""


class OutputActivationModifier:
    """
    Representing object responsible for adding standalone activation layers after output layers or
    modifying predecessors activation when possible.

    For every given output layer: if it is a ``LayerWithActivation`` whose activation is fusible, its
    activation is overwritten in place; otherwise a new ``FusedStandaloneActivationLayer`` carrying the
    requested activation is pushed after it in the graph.

    Args:
        hailo_nn (:class:`~hailo_sdk_common.hailo_nn.hailo_nn.HailoNN`): The Hailo NN to add Standalone
            Activation layers to, or modify its output activations.
        params (:class:`~hailo_sdk_common.model_params.model_params.ModelParams`): The params to add Activation
            layer params. NOTE(review): stored as ``self._params`` but not read by any method of this class.
        activation: The activation to apply. Its ``.value`` is used to build default layer names, so it is
            presumably an ``ActivationType`` member -- confirm with callers.
        output_layer_name (sized iterable of layers): The output layer objects to add a standalone
            activation after, or whose activation to modify if possible. Despite the parameter name, this
            is iterated as a collection of layer objects, not a single name.
            NOTE(review): ``None`` is not handled (``len(None)`` raises ``TypeError``), so callers must
            pass the layers explicitly.
        activation_layers_names (list of str, optional): Names for the new activation layers, one per
            output layer. Defaults to ``"<activation value><1-based index>"``. The list is copied, so the
            caller's list is never modified.

    Raises:
        OutputActivationModifierException: If ``activation_layers_names`` is given and its length differs
            from the number of output layers.
    """

    def __init__(self, hailo_nn, params, activation, output_layer_name=None, activation_layers_names=None):
        self._hailo_nn = hailo_nn
        self._fuser_helper = FuserHelper(hailo_nn)
        self._params = params
        self._activation = activation
        self._output_layers = output_layer_name
        # Must run after _activation/_output_layers are set: it reads both.
        self._activation_layer_names = self._get_activation_layer_names(activation_layers_names)
        self._logger = default_logger()
        # Base index for new standalone layers; the layer for output ``idx`` gets ``base + idx``.
        self._graph_next_index = self._hailo_nn.get_next_index()

    @property
    def output_layers(self):
        """The output layers this modifier operates on, as passed to the constructor."""
        return self._output_layers

    @property
    def activation_type(self):
        """The activation applied to the outputs."""
        return self._activation

    @property
    def activation_layer_names(self):
        """Per-output names of the layers carrying the activation (updated by apply_activation_modification)."""
        return self._activation_layer_names

    def _get_activation_layer_names(self, activation_layers_names):
        """
        Resolve the names for the new activation layers, one per output layer.

        Args:
            activation_layers_names (list of str or None): User-given names. If empty or None, default
                names ``"<activation value><idx + 1>"`` are generated.

        Returns:
            list of str: A new list of names. It is never the caller's list object, because
            ``apply_activation_modification`` overwrites its entries in place.

        Raises:
            OutputActivationModifierException: If the number of given names differs from the number of
                output layers.
        """
        num_of_outputs = len(self._output_layers)
        if not activation_layers_names:
            return [f"{self.activation_type.value}{idx + 1}" for idx in range(num_of_outputs)]

        if len(activation_layers_names) == num_of_outputs:
            # Copy so later in-place updates don't leak into the list the caller passed in.
            return list(activation_layers_names)

        raise OutputActivationModifierException(
            f"Given {len(activation_layers_names)} names for the new activations when "
            f"{num_of_outputs} names are required",
        )

    def apply_activation_modification(self):
        """
        Add standalone activation layers or modify pred activation if possible.

        Side effects: mutates ``self._hailo_nn`` (layer activations and/or new layers). It also overwrites
        each entry of ``activation_layer_names`` with the name of the layer that ends up carrying the
        activation.

        Returns:
            tuple: ``(hailo_nn, original_activations)``. ``original_activations[idx]`` is the previous
            activation of output ``idx`` when it was modified in place. It is ``ActivationType.linear``
            when a standalone layer was added instead; in that case the output layer itself is not touched.
        """
        self._logger.verbose("Adding standalone activation layers or modifying predecessors activation.")
        original_activations = []

        for idx, output_layer in enumerate(self._output_layers):
            if isinstance(output_layer, LayerWithActivation) and output_layer.is_activation_fusible:
                # Fuse into the existing layer: remember its activation, then replace it.
                original_activations.append(output_layer.activation)
                output_layer.activation = self._activation
                self.activation_layer_names[idx] = output_layer.name
            else:
                original_activations.append(ActivationType.linear)
                standalone_activation_layer = self._create_standalone_activation_layer(idx)
                self._hailo_nn.push_layer(standalone_activation_layer, [output_layer])
                orig_name = f"{MODEL_SCRIPT_LAYER_PREFIX}_{standalone_activation_layer.name_without_scope}"
                standalone_activation_layer.add_original_name(orig_name)
                # Name read back from the layer after push_layer, not the requested name.
                self.activation_layer_names[idx] = standalone_activation_layer.name

        return self._hailo_nn, original_activations

    def _create_standalone_activation_layer(self, idx):
        """
        Build (without inserting) the standalone activation layer for output ``idx``.

        Args:
            idx (int): Position of the output layer; selects the name and offsets the graph index.

        Returns:
            FusedStandaloneActivationLayer: Layer named ``activation_layer_names[idx]``, with index
            ``_graph_next_index + idx`` and activation set to the requested activation.
        """
        standalone_activation_layer = FusedStandaloneActivationLayer()
        standalone_activation_layer.name = self._activation_layer_names[idx]
        standalone_activation_layer.index = self._graph_next_index + idx
        standalone_activation_layer.activation = self._activation
        return standalone_activation_layer
