from hailo_sdk_client.hw_consts.hw_arch import HWArch
from hailo_sdk_client.post_fuser.algorithms.exceptions import LayerNormMappingException
from hailo_sdk_client.post_fuser.algorithms.fuser_algorithm import FuserAlgorithm
from hailo_sdk_client.tools.fuser.fuser_helper import FuserHelper
from hailo_sdk_common.hailo_nn.hn_definitions import (
    ActivationType,
    FeatureMultiplierType,
    HnStage,
    LayerType,
    PaddingType,
)
from hailo_sdk_common.hailo_nn.hn_layers import (
    EWMultLayer,
    EWSubLayer,
    FeatureMultiplierLayer,
    FusedStandaloneEWSubLayer,
    NormalizationLayer,
    PoolingLayer,
    ReduceMeanLayer,
)
from hailo_sdk_common.numeric_utils.normalization_params import calc_normalization_params


class LayerNormMapping(FuserAlgorithm):
    """Replace every layer-normalization layer with a sub-graph of simpler layers.

    For each layer whose ``op`` is ``LayerType.layer_normalization`` the pass builds:

    * (skipped when ``layer.rms_norm`` is set) a mean block followed by an
      element-wise subtraction of that mean from the input,
    * a ``square`` feature-multiplier layer,
    * a second mean block over the squared values,
    * a normalization layer with ``inv_sqrt`` activation that carries the epsilon,
    * an element-wise multiplication of the (centered) input with the result.

    The element-wise layers are collected and handed to
    ``FuserHelper.run_broadcast_ew`` once all layer norms have been split.
    """

    NAME = "layer_norm_mapping"

    def __init__(self, model, params, model_config, hw_arch, **kwargs):
        super().__init__(model, params, model_config, hw_arch, **kwargs)
        self._fuser_helper = FuserHelper(self._model)
        # Element-wise layers created by this pass that are later passed to run_broadcast_ew.
        self._layers_to_broadcast = []
        # Running counter used to assign an index to every layer this pass creates.
        self._next_index = self.model.get_next_index()

    def get_algo_config(self):
        """Return the model configuration this algorithm was constructed with."""
        return self._model_config

    def _setup(self):
        """No setup is required for this algorithm."""

    def _run_int(self):
        """Split all layer norms and broadcast the new element-wise layers.

        Does nothing when the target architecture is one of ``HWArch.PLUTO_ARCHS``.
        """
        if self._hw_arch.name in HWArch.PLUTO_ARCHS:
            return

        self._split_layer_norm()
        self._fuser_helper.run_broadcast_ew(layers=self._layers_to_broadcast)

    def should_skip_algo(self):
        # NOTE(review): this unconditionally returns True. If the base class skips the
        # algorithm whenever this is truthy, the pass never runs - confirm this is intended.
        return True

    def log_config(self):
        """Nothing to log for this algorithm."""

    def export_statistics(self):
        """This algorithm exports no statistics."""

    def _allocate_index(self):
        """Return the next free layer index and advance the running counter."""
        index = self._next_index
        self._next_index += 1
        return index

    def _split_layer_norm(self):
        """
        Split layer normalization layers according to the formula.

        Every layer-normalization layer is replaced via ``split_single_layer_norm``; the
        replaced layers are then removed from the model and the new layers are relaxed
        into the graph.
        """
        new_layers = []
        layers_to_remove = []
        successors_meta_data = {}

        # Iterate over a snapshot: split_single_layer_norm adds nodes to the model.
        for layer in list(self.model):
            if layer.op == LayerType.layer_normalization:
                self.split_single_layer_norm(new_layers, layers_to_remove, layer)

        for layer in layers_to_remove:
            self.model.remove_layer(layer)

        for layer in new_layers:
            self.model.relax_new_layer_into_graph(layer, successors_meta_data)

    def split_single_layer_norm(self, new_layers, layers_to_remove, layer):
        """Replace a single layer-normalization layer with its equivalent sub-graph.

        Args:
            new_layers: list that every newly created layer is appended to.
            layers_to_remove: list that ``layer`` is appended to once it has been replaced.
            layer: the layer-normalization layer to split. Only its first predecessor is
                used as the input of the new sub-graph.

        Raises:
            LayerNormMappingException: if ``layer.axes`` is empty or holds an unsupported axis.
        """
        scope = f"{layer.scope}/" if layer.scope else ""
        # Resolved before branching: the names are needed on both the regular and the
        # RMS-norm paths (previously they were only bound on the non-RMS path).
        block_name, layer_name = self.get_block_and_layer_names(layer.name_without_scope)
        pred = next(iter(self.model.predecessors(layer)))
        succs = list(self.model.successors(layer))

        if not layer.rms_norm:
            # Regular layer norm: subtract the mean from the input first.
            first_mean_in, first_mean_out, _ = self._build_layer_norm_mean_block(layer, new_layers, 1)

            ew_sub = FusedStandaloneEWSubLayer() if self.model.net_params.stage == HnStage.HN else EWSubLayer()
            ew_sub.index = self._allocate_index()
            ew_sub.name = f"{scope}{block_name}ew_sub_{layer_name}"
            ew_sub.output_shapes = [layer.output_shape.copy()]
            ew_sub.move_params(layer)
            self.model.add_node(ew_sub)
            new_layers.append(ew_sub)
            self._layers_to_broadcast.append(ew_sub)

        square = FeatureMultiplierLayer()
        square.feature_multiplier_type = FeatureMultiplierType.square
        square.index = self._allocate_index()
        square.name = f"{scope}{block_name}square_{layer_name}"
        square.output_shapes = [layer.output_shape.copy()]
        square.move_params(layer)
        self.model.add_node(square)
        new_layers.append(square)

        second_mean_in, second_mean_out, reduced_shape = self._build_layer_norm_mean_block(layer, new_layers, 2)

        norm_epsilon = NormalizationLayer()
        norm_epsilon.activation = ActivationType.inv_sqrt
        norm_epsilon.index = self._allocate_index()
        norm_epsilon.name = f"{scope}{block_name}normalization_{layer_name}"
        # Copy so the layer does not share its shape list with the caller-visible value.
        norm_epsilon.output_shapes = [reduced_shape.copy()]
        norm_epsilon.move_params(layer)
        self.model.add_node(norm_epsilon)
        new_layers.append(norm_epsilon)

        ew_mult = EWMultLayer()
        ew_mult.index = self._allocate_index()
        ew_mult.name = f"{scope}{block_name}ew_mult_{layer_name}"
        ew_mult.output_shapes = [layer.output_shape.copy()]
        ew_mult.move_params(layer)
        self.model.add_node(ew_mult)
        new_layers.append(ew_mult)
        self._layers_to_broadcast.append(ew_mult)

        # Detach the original layer from the graph.
        self.model.remove_edge(pred, layer)
        for succ in succs:
            self.model.remove_edge(layer, succ)

        # Default wiring is the RMS-norm variant (no mean subtraction).
        layer_to_inputs = {
            square: [pred],
            second_mean_in: [square],
            norm_epsilon: [second_mean_out],
            ew_mult: [pred, norm_epsilon],
        }
        layer_to_inputs.update(
            {succ: [ew_mult] for succ in succs},
        )

        layer_to_outputs = {
            pred: [square, ew_mult],
            square: [second_mean_in],
            second_mean_out: [norm_epsilon],
            norm_epsilon: [ew_mult],
            ew_mult: succs,
        }

        if not layer.rms_norm:
            # Regular layer norm: route the input through the mean block and ew_sub,
            # overriding the entries of square / ew_mult / pred set above.
            layer_to_inputs.update(
                {
                    first_mean_in: [pred],
                    ew_sub: [pred, first_mean_out],
                    square: [ew_sub],
                    ew_mult: [ew_sub, norm_epsilon],
                },
            )

            layer_to_outputs.update(
                {
                    pred: [ew_sub, first_mean_in],
                    first_mean_out: [ew_sub],
                    ew_sub: [ew_mult, square],
                },
            )

        for curr_layer, inputs in layer_to_inputs.items():
            self._fuser_helper.add_preds(curr_layer, inputs)

        # Distinct loop names so the original successor list `succs` is not rebound.
        for curr_layer, outputs in layer_to_outputs.items():
            self._fuser_helper.add_succs(curr_layer, outputs)

        if self._params:
            # Epsilon is passed negated as the normalization "mean", with a std of 1.
            kernel, bias = calc_normalization_params(
                -self._params[f"{layer.name}/epsilon:0"],
                [1],
                norm_epsilon.kernel_shape,
            )
            self._params.update({f"{norm_epsilon.name}/kernel:0": kernel, f"{norm_epsilon.name}/bias:0": bias})
            self._params.remove(layer.name)
        else:
            # For fuser flow
            norm_epsilon.mean = [-layer.epsilon]
            norm_epsilon.std = [1]

        layers_to_remove.append(layer)
        self._logger.debug(f"Replaced layer normalization layer {layer.name} with layers that run on LCU")

    def _build_layer_norm_mean_block(self, layer_norm, new_layers, index):
        """Create the layers that average ``layer_norm``'s input over its axes.

        Axis 3, and axis 2 when ``layer_norm.input_height == 1``, are reduced by a
        ``ReduceMeanLayer``; the remaining axes 1 and 2 are reduced by an average-pooling
        layer whose kernel and stride equal the input size along those axes. When both
        layers are needed, the pooling layer feeds the reduce-mean layer.

        Args:
            layer_norm: the layer-normalization layer being split.
            new_layers: list that every newly created layer is appended to.
            index: number embedded in the new layers' names to tell the mean blocks
                of one layer norm apart.

        Returns:
            A ``(first_layer, last_layer, output_shape)`` tuple. ``first_layer`` and
            ``last_layer`` are the same object when only one layer was needed.

        Raises:
            LayerNormMappingException: if ``layer_norm.axes`` is empty or contains an
                axis other than 1, 2 or 3.
        """
        if not layer_norm.axes:
            raise LayerNormMappingException(f"Invalid axes {layer_norm.axes} found in {layer_norm.full_name_msg}")

        scope = f"{layer_norm.scope}/" if layer_norm.scope else ""
        reduce_mean_axes = []
        avgpool_axes = []
        for axis in layer_norm.axes:
            if axis == 3 or (axis == 2 and layer_norm.input_height == 1):
                reduce_mean_axes.append(axis)
            elif axis in [1, 2]:
                avgpool_axes.append(axis)
            else:
                raise LayerNormMappingException(f"Invalid axis {axis} found in {layer_norm.full_name_msg}")

        block_name, layer_norm_name = self.get_block_and_layer_names(layer_norm.name_without_scope)
        avgpool_shape = layer_norm.input_shape.copy()
        # Two independent lists: kernel and strides must not alias the same object.
        avgpool_kernel = [1, 1, 1, 1]
        avgpool_strides = [1, 1, 1, 1]
        for axis in avgpool_axes:
            avgpool_shape[axis] = 1
            avgpool_kernel[axis] = layer_norm.input_shape[axis]
            avgpool_strides[axis] = layer_norm.input_shape[axis]

        reduce_mean_shape = avgpool_shape.copy()
        for axis in reduce_mean_axes:
            reduce_mean_shape[axis] = 1

        avgpool = None
        if avgpool_axes:
            avgpool = PoolingLayer()
            avgpool.op = LayerType.avgpool
            avgpool.padding = PaddingType.valid
            avgpool.index = self._allocate_index()
            avgpool.name = f"{scope}{block_name}avgpool{index}_{layer_norm_name}"
            avgpool.output_shapes = [avgpool_shape.copy()]
            avgpool.move_params(layer_norm)
            self.model.add_node(avgpool)
            new_layers.append(avgpool)
            avgpool.kernel_shape = avgpool_kernel
            avgpool.strides = avgpool_strides

        reduce_mean = None
        if reduce_mean_axes:
            reduce_mean = ReduceMeanLayer()
            reduce_mean.index = self._allocate_index()
            reduce_mean.name = f"{scope}{block_name}reduce_mean{index}_{layer_norm_name}"
            reduce_mean.output_shapes = [reduce_mean_shape.copy()]
            reduce_mean.reduce_axes = reduce_mean_axes
            reduce_mean.move_params(layer_norm)
            self.model.add_node(reduce_mean)
            new_layers.append(reduce_mean)

        if avgpool is None:
            return reduce_mean, reduce_mean, reduce_mean_shape

        if reduce_mean is None:
            return avgpool, avgpool, avgpool_shape

        # Both layers exist: chain avgpool -> reduce_mean inside the block.
        self.model.add_edge(avgpool, reduce_mean)
        avgpool.append_output_layer(reduce_mean.name)
        avgpool.append_output_index(reduce_mean.index)
        avgpool.output_shapes = [avgpool_shape.copy()]
        reduce_mean.append_input_layer(avgpool.name)
        reduce_mean.append_input_index(avgpool.index)
        reduce_mean.input_shapes = [avgpool_shape.copy()]
        return avgpool, reduce_mean, reduce_mean_shape
