from hailo_sdk_common.hailo_nn.hailo_nn import hn_to_npz_key
from hailo_sdk_common.hailo_nn.hn_layers import NormalizationLayer
from hailo_sdk_common.hailo_nn.hn_layers.layer_common import MODEL_SCRIPT_LAYER_PREFIX
from hailo_sdk_common.logger.logger import default_logger


class NormalizationLayersAdditionException(Exception):
    """Raised when normalization layers cannot be added (e.g. invalid input layer name or mismatched sizes)."""


class NormalizationLayersAdder:
    """
    Representing object responsible for adding normalization layers after input layers.

    Args:
        hailo_nn (:class:`~hailo_sdk_common.hailo_nn.hailo_nn.HailoNN`): The Hailo NN to add Normalization layers to.
        params (:class:`~hailo_sdk_common.model_params.model_params.ModelParams`): The params to add Normalization
            layer params.
        mean (list of floats): Mean values for each feature in the input layer.
        std (list of floats): Std values for each feature in the input layer.
        normalization_layers (list of str, optional): Names for the new Normalization layers, one per input layer
            to normalize. Defaults to ``normalization1``, ``normalization2``, ...
        input_layer_name (str, optional): Name of the input layer to normalize.
            Defaults to all input layers; in that case, all input layers must have the same number of features,
            since the same ``mean`` and ``std`` are applied to each of them.

    Raises:
        NormalizationLayersAdditionException: If ``input_layer_name`` is not an input layer of ``hailo_nn``, or if
            the number of given ``normalization_layers`` names does not match the number of inputs to normalize.

    """

    def __init__(self, hailo_nn, params, mean, std, normalization_layers=None, input_layer_name=None):
        self._hailo_nn = hailo_nn
        self._params = params
        self._mean = mean
        self._std = std
        self._input_layers = self._get_input_names(input_layer_name)
        self._normalization_names = self._get_normalization_names(normalization_layers)
        self._logger = default_logger()
        # New layers receive consecutive indices starting at this value (see _create_normalization_layer).
        self._graph_next_index = self._hailo_nn.get_next_index()

    @property
    def input_layers(self):
        return self._input_layers

    @property
    def mean(self):
        return self._mean

    @property
    def std(self):
        return self._std

    @property
    def normalization_names(self):
        return self._normalization_names

    def _get_normalization_names(self, normalization_layers):
        """
        Return one name per input layer to normalize.

        Generates ``normalization<i>`` names (1-based) when no names are given.

        Raises:
            NormalizationLayersAdditionException: If the number of given names differs from the number of
                input layers to normalize.
        """
        num_of_inputs_to_normalize = len(self._input_layers)
        if not normalization_layers:
            return [f"normalization{idx + 1}" for idx in range(num_of_inputs_to_normalize)]

        if len(normalization_layers) == num_of_inputs_to_normalize:
            return normalization_layers

        raise NormalizationLayersAdditionException(
            f"Given {len(normalization_layers)} names for the new layers when "
            f"{num_of_inputs_to_normalize} names are required",
        )

    def add_normalization_layers(self):
        """
        Add a normalization layer after each selected input layer and register its kernel/bias params.

        All input layers are validated before the graph is touched, so a size mismatch on any input leaves
        ``hailo_nn`` and ``params`` unmodified.

        Returns:
            tuple: ``(hailo_nn, params)``, the same objects passed to the constructor, modified in place.

        Raises:
            NormalizationLayersAdditionException: If the number of std or mean values differs from the number of
                features of an input layer.
        """
        self._logger.verbose("Adding normalization layers")

        # Validate up front: the loop below mutates the graph and params in place, so failing half-way through
        # would leave a partially normalized model behind.
        for input_layer in self._input_layers:
            self._validate_input_layer(input_layer)

        for idx, input_layer in enumerate(self._input_layers):
            normalization_layer = self._create_normalization_layer(idx, self._mean, self._std)
            self._hailo_nn.push_layer(normalization_layer, [input_layer])
            self._set_normalization_layer_params(normalization_layer)
            orig_name = f"{MODEL_SCRIPT_LAYER_PREFIX}_{normalization_layer.name_without_scope}"
            normalization_layer.add_original_name(orig_name)

        return self._hailo_nn, self._params

    def _validate_input_layer(self, input_layer):
        """Raise if the std or mean length does not match ``input_layer``'s number of output features."""
        features = input_layer.output_features
        # std is checked before mean, matching the order in which errors were historically reported.
        for values_name, values in (("std", self._std), ("mean", self._mean)):
            if len(values) != features:
                raise NormalizationLayersAdditionException(
                    f"{input_layer.name} has {features} features and"
                    f" given {len(values)} {values_name} values. They must be equal",
                )

    def _set_normalization_layer_params(self, normalization_layer):
        """Store the layer's kernel and bias in the params under their npz keys."""
        self._params[hn_to_npz_key(normalization_layer.name, "kernel")] = normalization_layer.kernel
        self._params[hn_to_npz_key(normalization_layer.name, "bias")] = normalization_layer.bias

    def _create_normalization_layer(self, idx, mean, std):
        """Build the ``idx``-th normalization layer with its configured name, graph index, mean and std."""
        normalization_layer = NormalizationLayer()
        normalization_layer.name = self._normalization_names[idx]
        normalization_layer.index = self._graph_next_index + idx
        normalization_layer.mean = mean
        normalization_layer.std = std
        return normalization_layer

    def _get_input_names(self, input_name_to_normalize):
        """
        Return the input layers (layer objects, despite the method name) to normalize.

        Args:
            input_name_to_normalize (str, optional): Name of a single input layer. When falsy, all input layers
                of the network are returned.

        Raises:
            NormalizationLayersAdditionException: If the named layer is not one of the network's input layers.
        """
        input_layers = self._hailo_nn.get_input_layers()
        if not input_name_to_normalize:
            return input_layers
        # NOTE(review): behavior of get_layer_by_name for an unknown name is not visible here; it may raise
        # before the membership check below is reached.
        input_layer_to_normalize = self._hailo_nn.get_layer_by_name(input_name_to_normalize)
        if input_layer_to_normalize not in input_layers:
            # input_name_to_normalize is a plain str, so it is formatted directly (it has no `.name`).
            raise NormalizationLayersAdditionException(f"{input_name_to_normalize} is not a valid input name")

        return [input_layer_to_normalize]
