import numpy as np

from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_sdk_client.post_fuser.algorithms.fuser_algorithm import FuserAlgorithm
from hailo_sdk_client.tools.fuser.fuser_helper import FuserHelper
from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN, hn_to_npz_key
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, BlockType
from hailo_sdk_common.hailo_nn.hn_layers.fused_standalone_activation_layer import FusedStandaloneActivationLayer
from hailo_sdk_common.hailo_nn.hn_layers.fused_standalone_ew_add import FusedStandaloneEWAddLayer
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_activation import LayerWithActivation
from hailo_sdk_common.hailo_nn.hn_layers.normalization import NormalizationLayer
from hailo_sdk_common.model_params.model_params import ModelParams
from hailo_sdk_common.numeric_utils.normalization_params import calc_normalization_params


class SplitLeakyAndPReLUWithNegSlope(FuserAlgorithm):
    """Split leaky-ReLU / PReLU activations into an equivalent block of layers.

    A layer with activation is split when (see ``_get_alpha_std``):

    * its ``leaky_alpha`` param is negative (not on mercury archs, which support
      a negative leaky slope natively), or
    * its ``prelu_slope`` param has a negative entry or is not uniform.

    The layer's activation becomes linear and its output is routed through::

        layer -+-> relu ------------------------------------------------------+-> ew_add -> old successors
               +-> neg_normalization (std=-1, relu) -> alpha_normalization (std=-1/slope) -+

    NOTE(review): assuming ``calc_normalization_params`` yields ``kernel = 1/std`` and
    ``bias = -mean/std``, the block evaluates ``relu(x) + slope * min(x, 0)``, i.e. the
    original activation — confirm against that helper.
    """

    NAME = "split_leaky_with_neg_slope"

    # Sentinel for ``params.get`` so a missing key can be told apart from any stored value.
    _MISSING = object()

    def __init__(self, model: HailoNN, params: ModelParams, model_config: ModelOptimizationConfig, hw_arch, **kwargs):
        super().__init__(model, params, model_config, hw_arch, **kwargs)
        self._fuser_helper = FuserHelper(model)
        self._model_config = model_config

    def get_algo_config(self):
        """Return the model optimization config this algorithm was constructed with."""
        return self._model_config

    def _run_int(self):
        """Replace every leaky/PReLU activation that needs splitting with the equivalent block."""
        # Iterate over a snapshot of the layers: the loop body adds new layers to the model.
        for layer in list(self._model):
            if not isinstance(layer, LayerWithActivation):
                continue
            alpha_std = self._get_alpha_std(layer)
            if alpha_std is None:
                continue
            self._split_activation(layer, alpha_std)

    def _get_alpha_std(self, layer):
        """Return the per-channel std (``-1 / slope``) for the alpha normalization, or None to leave ``layer`` as is."""
        leaky_slope = self.params.get(hn_to_npz_key(layer.name, "leaky_alpha"), 1)
        if leaky_slope < 0:
            if self._hw_arch.is_mercury_arch:
                # mercury supports negative slope for leaky relu
                return None
            return np.full(layer.output_features, -1 / leaky_slope)

        # Not a negative leaky slope - the layer might be a prelu. Flatten so a scalar or
        # multi-dimensional slope param is handled like a 1-D per-channel vector.
        prelu_slope = np.ravel(self.params.get(hn_to_npz_key(layer.name, "prelu_slope"), [1]))
        if np.all(prelu_slope >= 0) and np.unique(prelu_slope).size == 1:
            # Uniform non-negative slope; this is also the outcome when no slope param is stored.
            return None
        if prelu_slope.size == 1:
            # A single shared slope applies to every output channel.
            prelu_slope = np.repeat(prelu_slope, layer.output_features)
        return -1 / prelu_slope

    def _create_block_layers(self, layer):
        """Create the block's layers (neg_normalization, relu, ew_add, alpha_normalization) with consecutive indices."""
        base_index = self._model.get_next_index()
        shape = layer.output_shapes[0]

        neg_normalization = NormalizationLayer()
        neg_normalization.name = f"{layer.name}_normalization{base_index}"
        neg_normalization.index = base_index
        neg_normalization.original_names = layer.original_names.copy()
        neg_normalization.input_shapes = [shape.copy()]
        neg_normalization.output_shapes = [shape.copy()]
        neg_normalization.block_info = (BlockType.PRELU, layer.name)
        neg_normalization.activation = ActivationType.relu
        base_index += 1

        relu = FusedStandaloneActivationLayer()
        relu.name = f"{layer.name}_relu{base_index}"
        relu.index = base_index
        relu.original_names = layer.original_names.copy()
        relu.input_shapes = [shape.copy()]
        relu.output_shapes = [shape.copy()]
        relu.activation = ActivationType.relu
        relu.block_info = (BlockType.PRELU, layer.name)
        base_index += 1

        ew_add = FusedStandaloneEWAddLayer()
        ew_add.index = base_index
        ew_add.name = f"{layer.name}_ew_add{base_index}"
        ew_add.original_names = layer.original_names.copy()
        ew_add.input_shapes = [shape.copy(), shape.copy()]
        # Copy every shape so ew_add does not share mutable shape lists with ``layer``.
        ew_add.output_shapes = [output_shape.copy() for output_shape in layer.output_shapes]
        ew_add.block_info = (BlockType.PRELU, layer.name)
        base_index += 1

        alpha_normalization = NormalizationLayer()
        alpha_normalization.name = f"{layer.name}_normalization2{base_index}"
        alpha_normalization.index = base_index
        alpha_normalization.original_names = layer.original_names.copy()
        alpha_normalization.input_shapes = [shape.copy()]
        alpha_normalization.output_shapes = [shape.copy()]
        alpha_normalization.block_info = (BlockType.PRELU, layer.name)

        return neg_normalization, relu, ew_add, alpha_normalization

    @staticmethod
    def _calc_block_params(layer, neg_normalization, alpha_normalization, alpha_std):
        """Return the kernel/bias params of both normalization layers, keyed by npz key."""
        f_in = layer.output_features
        neg_normalization_kernel, neg_normalization_bias = calc_normalization_params(
            mean=np.zeros(f_in, dtype=int),
            std=np.full(f_in, -1),
            kernel_shape=[1, 1, f_in, 1],
        )
        alpha_normalization_kernel, alpha_normalization_bias = calc_normalization_params(
            mean=np.zeros(f_in, dtype=int),
            std=alpha_std,
            kernel_shape=[1, 1, f_in, 1],
        )
        return {
            hn_to_npz_key(neg_normalization.name, "kernel"): neg_normalization_kernel,
            hn_to_npz_key(neg_normalization.name, "bias"): neg_normalization_bias,
            hn_to_npz_key(alpha_normalization.name, "kernel"): alpha_normalization_kernel,
            hn_to_npz_key(alpha_normalization.name, "bias"): alpha_normalization_bias,
        }

    def _remove_param_if_present(self, key):
        """Remove ``key`` from the params only if it is stored (a PReLU layer may lack ``leaky_alpha``)."""
        if self.params.get(key, self._MISSING) is not self._MISSING:
            self.params.remove(key)

    def _split_activation(self, layer, alpha_std):
        """Make ``layer`` linear and route its output through the equivalent relu/normalization/ew_add block."""
        neg_normalization, relu, ew_add, alpha_normalization = self._create_block_layers(layer)
        self.params.update(self._calc_block_params(layer, neg_normalization, alpha_normalization, alpha_std))

        # Detach the original successors and feed them from the block's ew_add instead.
        succs = list(self._model.successors(layer))
        for succ in succs:
            self._fuser_helper.remove_succ(layer, succ)
            self._fuser_helper.remove_pred(succ, layer)
            self._fuser_helper.add_preds(succ, [ew_add])

        # The activation is now computed by the block, so the slope params are stale.
        layer.activation = ActivationType.linear
        self._remove_param_if_present(hn_to_npz_key(layer.name, "leaky_alpha"))
        self._remove_param_if_present(hn_to_npz_key(layer.name, "prelu_slope"))

        output_layers_order = self.model.net_params.output_layers_order
        if layer.name in output_layers_order:
            # ew_add now produces what used to be the layer's output.
            output_layers_order[output_layers_order.index(layer.name)] = ew_add.name

        # connects the new layers to the graph; every edge is registered on both ends
        self._fuser_helper.add_succs(layer, [relu], update_output_shapes=False)
        self._fuser_helper.add_preds(relu, [layer], update_input_shapes=False)
        self._fuser_helper.add_succs(relu, [ew_add], update_output_shapes=False)

        self._fuser_helper.add_succs(layer, [neg_normalization], update_output_shapes=False)
        self._fuser_helper.add_preds(neg_normalization, [layer], update_input_shapes=False)
        self._fuser_helper.add_succs(neg_normalization, [alpha_normalization], update_output_shapes=False)

        self._fuser_helper.add_preds(alpha_normalization, [neg_normalization], update_input_shapes=False)
        self._fuser_helper.add_succs(alpha_normalization, [ew_add], update_output_shapes=False)

        self._fuser_helper.add_preds(ew_add, [relu, alpha_normalization], update_input_shapes=False)
        self._fuser_helper.add_succs(ew_add, succs)

    def _setup(self):
        """No setup is needed for this algorithm."""
        pass

    def log_config(self):
        """This algorithm has no configuration to log."""
        pass

    def should_skip_algo(self):
        """Never skip: layers that need no split are filtered per layer in ``_run_int``."""
        return False
