from hailo_model_optimization.acceleras.utils.acceleras_definitions import PostprocessTarget
from hailo_sdk_client.tools.fuser.fuser_helper import FuserHelper, PostprocessAdditionMode
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers import ArgmaxLayer, SoftmaxLayer
from hailo_sdk_common.logger.logger import default_logger


class LogitsLayersAdditionException(Exception):
    """Raised when softmax/argmax logits layers cannot be added to the network."""


class LogitsLayersAdder:
    """Append a logits activation (softmax or argmax) after selected layers of a HailoNN.

    With ``engine == PostprocessTarget.NN_CORE`` the activation is inserted as a regular
    layer (``SoftmaxLayer`` / ``ArgmaxLayer``) directly after each target layer. Otherwise
    the work is delegated to ``FuserHelper.add_logits_as_postprocess_layer_to_hn`` with
    ``PostprocessAdditionMode.ADD_AS_OUTPUT``.
    """

    # Largest channel count accepted by the nn_core argmax validation below.
    MAX_CHANNELS_SUPPORTED_ARGMAX = 64
    # NOTE(review): not referenced in this class; presumably consumed elsewhere — confirm before removing.
    PPU_MAX_SUPPORTED_SOFTMAX_LAYERS = 2
    # NOTE(review): not referenced in this class; presumably consumed elsewhere — confirm before removing.
    SOFTMAX_SUPPORTED_RANK = 2
    # Rank the predecessor's output must have for an nn_core argmax.
    ARGMAX_SUPPORTED_RANK = 4

    def __init__(self, hailo_nn, layers, logits_layers_name, activation_type, axis, engine):
        """Store the configuration; no graph changes happen until ``add_logits_layers``.

        Args:
            hailo_nn: The network to modify; it is returned by ``add_logits_layers``.
            layers: Layers after which a logits activation is appended (one per layer).
            logits_layers_name: Requested names for the new layers, indexed like ``layers``.
                On the nn_core path this list is updated in place with the names the
                inserted layers actually end up with.
            activation_type: ``LayerType.softmax`` or ``LayerType.argmax``.
            axis: Activation axis. For the nn_core argmax it must be ``None``, ``-1`` or
                ``rank - 1`` of the predecessor's output.
            engine: A ``PostprocessTarget``; ``NN_CORE`` selects the nn_core insertion path.
        """
        self._hn = hailo_nn
        self._layers = layers
        self._logits_layers_name = logits_layers_name
        self._activation_type = activation_type
        self._axis = axis
        self._engine = engine
        self._logger = default_logger()

    @property
    def logits_layers_name(self):
        """Names of the logits layers (updated in place after nn_core insertion)."""
        return self._logits_layers_name

    @property
    def activation_type(self):
        """The requested logits activation (softmax or argmax ``LayerType``)."""
        return self._activation_type

    @property
    def axis(self):
        """The activation axis passed at construction."""
        return self._axis

    @property
    def engine(self):
        """The ``PostprocessTarget`` selecting where the activation runs."""
        return self._engine

    def add_logits_layers(self):
        """Add the logits layers using the configured engine and return the network.

        Returns:
            The (modified) HailoNN held by this instance.

        Raises:
            LogitsLayersAdditionException: On the nn_core path, if validation fails.
        """
        if self._engine == PostprocessTarget.NN_CORE:
            self.add_logits_layers_to_hn()
        else:
            FuserHelper.add_logits_as_postprocess_layer_to_hn(
                self._hn,
                self._layers,
                self._activation_type,
                self._axis,
                self._logits_layers_name,
                PostprocessAdditionMode.ADD_AS_OUTPUT,
            )

        return self._hn

    def add_logits_layers_to_hn(self):
        """Insert an nn_core softmax/argmax layer after each target layer.

        Each new layer takes over its predecessor's slot in the network's output order.
        If ``activation_type`` is neither softmax nor argmax, nothing is inserted.

        Raises:
            LogitsLayersAdditionException: If any target layer is a postprocess layer, or
                if the activation-specific validation fails.
        """
        if any(layer.op == LayerType.postprocess for layer in self._layers):
            raise LogitsLayersAdditionException(
                "nn_core logits layer can be applied only on layers that run on nn_core. "
                "Please change to `engine=cpu`",
            )

        if self._activation_type == LayerType.softmax:
            self.validate_softmax_conditions(self._layers)
            for i, layer in enumerate(self._layers):
                softmax = SoftmaxLayer.create(
                    self._logits_layers_name[i],
                    layer.name,
                    axis=self._axis,
                    output_shapes=layer.output_shapes,
                )
                self._insert_logits_layer(layer, softmax, i)

        elif self._activation_type == LayerType.argmax:
            self.validate_argmax_conditions(self._layers)
            for i, layer in enumerate(self._layers):
                argmax = ArgmaxLayer.create(self._logits_layers_name[i], layer.name, output_shapes=layer.output_shapes)
                self._insert_logits_layer(layer, argmax, i)

    def _insert_logits_layer(self, layer, logits_layer, index):
        """Push ``logits_layer`` after ``layer``, record its name and update the output order."""
        self._hn.push_layer(logits_layer, [layer])
        # Read the name back after push_layer, as the original code does (the name may differ from the request).
        self._logits_layers_name[index] = logits_layer.name
        self._replace_in_output_order(layer.name, logits_layer.name)

    def _replace_in_output_order(self, old_name, new_name):
        """Swap ``old_name`` for ``new_name`` in the output order, if it is listed there."""
        output_order = self._hn.net_params.output_layers_order
        # Guarded for both activations so a missing entry cannot raise ValueError mid-insertion.
        if old_name in output_order:
            output_order[output_order.index(old_name)] = new_name

    def _feeds_single_output_layer(self, layer):
        """Return True when exactly one successor of ``layer`` is an output layer."""
        return sum(1 for successor in self._hn.successors(layer) if successor.op == LayerType.output_layer) == 1

    def validate_softmax_conditions(self, layers):
        """Check that every layer in ``layers`` is an output layer.

        Raises:
            LogitsLayersAdditionException: With one newline-terminated message per failing layer.
        """
        errors = []
        for layer in layers:
            if self._feeds_single_output_layer(layer):
                continue
            message = f"The layer {layer.name} must be an output layer"
            output_order = self._hn.net_params.output_layers_order
            # Only suggest an alternative when there is one; indexing [0] on an empty order would raise IndexError.
            if output_order:
                message += f", did you mean {output_order[0]}?"
            errors.append(message + "\n")
        if errors:
            raise LogitsLayersAdditionException("".join(errors))

    def validate_argmax_conditions(self, layers):
        """Check the nn_core argmax constraints for every layer in ``layers``.

        Each layer must be an output layer, have an output of rank ``ARGMAX_SUPPORTED_RANK``,
        be reduced over its last (channels) axis, and have at most
        ``MAX_CHANNELS_SUPPORTED_ARGMAX`` output channels.

        Raises:
            LogitsLayersAdditionException: With one newline-terminated message per violation.
        """
        errors = []
        for layer in layers:
            rank = len(layer.output_shape)
            if not self._feeds_single_output_layer(layer):
                errors.append(f"The layer {layer.name} must be an output layer\n")
            if rank != self.ARGMAX_SUPPORTED_RANK:
                errors.append(
                    f"Can't add Argmax layer. The predecessor layer must have rank {self.ARGMAX_SUPPORTED_RANK}.\n"
                )
            if self._axis is not None and self._axis not in (rank - 1, -1):
                errors.append("Can't add Argmax layer. The axis must represent the channels dimension.\n")
            # The argmax consumes this layer's *output* (it is created with layer.output_shapes),
            # so the channel limit applies to the last output dimension, not the layer's input.
            if layer.output_shape[-1] > self.MAX_CHANNELS_SUPPORTED_ARGMAX:
                errors.append(
                    f"Can't add Argmax layer. The maximum supported channels are {self.MAX_CHANNELS_SUPPORTED_ARGMAX}.\n"
                )

        if errors:
            raise LogitsLayersAdditionException("".join(errors))
