#!/usr/bin/env python

import numpy as np

from hailo_sdk_client.post_fuser.algorithms.fuser_algorithm import FuserAlgorithm
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, LayerType
from hailo_sdk_common.hailo_nn.hn_layers import NormalizationLayer
from hailo_sdk_common.model_params.model_params import ModelParams


class AddPreLnNormalization(FuserAlgorithm):
    """Add a pass-through normalization layer next to every layer-normalization layer.

    For each layer whose op is ``LayerType.layer_normalization``, a new
    ``NormalizationLayer`` (mean 0, std 1, linear activation, kernel of ones,
    bias of zeros) is pushed after the layer that feeds the layer norm's first
    input, and matching kernel/bias entries are added to the model params.

    TODO: this algorithm was added for one specific use case - a unet model in
    which a dummy pre-LN normalization layer is wanted in order to support
    equalization on the conv layer that precedes the layer normalization.
    It is added as skipped (``should_skip_algo`` always returns True); alls
    support for enabling it can be added in the future.
    """

    NAME = "add_pre_ln_normalization"

    def __init__(self, model, params, model_config, hw_arch, **kwargs):
        """Initialize the base algorithm; extra keyword arguments are accepted but not forwarded."""
        super().__init__(model, params, model_config, hw_arch)

    def get_algo_config(self):
        """Return None - this algorithm reads no configuration."""
        return None

    def _setup(self):
        """No setup is performed for this algorithm."""
        return None

    def _run_int(self):
        """Add the pre-LN normalization layer for every layer-normalization layer in the model."""
        # Collect the targets up front: the model gains layers while we process them.
        layer_norms = [candidate for candidate in self._model if candidate.op == LayerType.layer_normalization]
        for layer_norm in layer_norms:
            self._add_pre_ln_normalization(layer_norm)

    def _add_pre_ln_normalization(self, layer_norm):
        """Create, connect and parametrize the normalization layer for a single layer norm.

        Args:
            layer_norm: the layer-normalization layer whose first input's producer
                receives the new normalization layer.
        """
        producer_name = self.model.get_layer_name_by_index(layer_norm.input_indices[0])
        producer = self.model.get_layer_by_name(producer_name)

        new_index = self.model.get_next_index()

        # Layer names are expected to look like "<scope>/<name>"; only the first
        # two "/"-separated components are used.
        name_parts = layer_norm.name.split("/")
        scope = name_parts[0]
        short_name = name_parts[1]
        consumer_name = f"{scope}/pre_ln_equalization_consumer_{short_name}"

        consumer = NormalizationLayer()
        consumer.name = consumer_name
        consumer.index = new_index
        consumer.original_names = layer_norm.original_names.copy()
        # Both the input and the output shape are taken from the layer norm's first output shape.
        consumer.input_shapes = [layer_norm.output_shapes[0].copy()]
        consumer.output_shapes = [layer_norm.output_shapes[0].copy()]
        consumer.mean = 0
        consumer.std = 1
        consumer.activation = ActivationType.linear

        # Shapes were assigned by hand above, so shape calculation is disabled here.
        self.model.push_layer(consumer, [producer], calc_shapes=False)
        # NOTE(review): the new layer's output indices are copied from the producer's
        # *input* indices - confirm this is the intended wiring.
        consumer.output_indices = producer.input_indices.copy()
        # In the producer's output list, the slot that held the layer norm now holds the new layer.
        slot = producer.output_indices.index(layer_norm.index)
        producer.output_indices[slot] = consumer.index

        features = consumer.output_features
        kernel = np.ones([1, 1, features, 1], dtype=np.float32)
        bias = np.zeros([features], dtype=np.float32)
        self._params.update(
            {
                f"{consumer_name}/kernel:0": kernel,
                f"{consumer_name}/bias:0": bias,
            }
        )

        # Rebuild the params object when its layer listing does not contain the new
        # layer (presumably update() alone does not refresh ``layers`` - confirm).
        if consumer_name not in self._params.layers:
            self._params = ModelParams(self._params.params)

    def should_skip_algo(self):
        """Always report the algorithm as skipped; it is intentionally added in a skipped state."""
        return True

    def log_config(self):
        """Nothing to log - the algorithm has no configuration."""
        return None
