from copy import copy
from dataclasses import dataclass
from functools import reduce

import networkx as nx
import numpy as np
from numpy.typing import NDArray

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_sub import HailoElementwiseSub
from hailo_model_optimization.acceleras.hailo_layers.hailo_feature_multiplier_on_mac import HailoFeatureMultiplierOnMac
from hailo_model_optimization.acceleras.hailo_layers.hailo_io import HailoInputLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_layer_norm import HailoLayerNorm
from hailo_model_optimization.acceleras.hailo_layers.hailo_layer_norm_mercury import HailoLayerNormMercury
from hailo_model_optimization.acceleras.hailo_layers.hailo_layer_normalization import HailoLayerNormalization
from hailo_model_optimization.acceleras.hailo_layers.hailo_precision_split import HailoPrecisionSplit
from hailo_model_optimization.acceleras.hailo_layers.op_factories import gen_acceleras_layers_from_hn
from hailo_model_optimization.acceleras.lossy_elements.quant_element import APUOutputQuantElement
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_base import CommandMeta
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerAdaRoundConfig,
    LayerBiasCorrectionConfig,
    LayerEqualizationConfig,
    LayerNegExponentConfig,
    LayerPrecisionConfig,
    LayerTranslationConfig,
    LayerZeroStaticChannelsConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BiasMode,
    EWMultType,
    FeaturePolicy,
    FinetunePolicy,
    LayerNormDecompositionMode,
    LayerNormMode,
    OptimizationTarget,
    PrecisionMode,
    ThreeWayPolicy,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasImplementationError
from hailo_model_optimization.algorithms.mixed_precision.create_mixed_precision import CreateMixedPrecision
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm
from hailo_model_optimization.algorithms.stats_collection.stats_collection import StatsCollector

SUPPORTED_LAYERS = [HailoLayerNormalization, HailoLayerNormMercury, HailoLayerNorm]
DEBUG = False


@dataclass
class EquivClassNorm:
    """Names of the layers that take part in the equalization of one decomposed layer norm."""

    # Name of the "equalization_source" normalization layer.
    source: str
    # Name of the first HailoConv found downstream of the source (None when no such layer was found).
    consumer_square: str
    # Name of the "equalization_consumer_out" normalization layer.
    consumer_out: str


@dataclass
class OnlineTokenEqualizationBlock:
    """Names of the layers that form the online token equalization block of one layer norm."""

    # Name of the standalone activation layer created with the "exp_decompose" activation.
    exp_decompose: str
    # Name of the elementwise-add layer created with the "shift" activation.
    shift: str


class DecomposeLayerNorm(OptimizationAlgorithm):
    def __init__(
        self,
        model: HailoModel,
        model_config,
        logger_level,
        dataset,
        **kwargs,
    ):
        """
        Initialize the layer norm decomposition algorithm.

        Args:
            model (HailoModel): model whose normalization layers are processed.
            model_config: model optimization configuration, forwarded to the base class.
            logger_level: logging level, forwarded to the base class.
            dataset: dataset kept for the statistics collection run after the decomposition
                (presumably unbatched, judging by the attribute name - verify against caller).
            **kwargs: forwarded untouched to the base class.
        """
        super().__init__(
            model,
            model_config,
            logger_level=logger_level,
            name="LayerNorm Decomposition",
            **kwargs,
        )
        # Cache keyed by layer norm name; populated outside of this constructor.
        self._layer_norm_name_split_cache = {}
        # Logger level handed to the sub-algorithms launched by this algorithm.
        self._logger_level_other = 9
        self._unbatched_dataset = dataset

    def _get_default_mode(self, mode):
        """
        Resolve the `auto` layer norm mode into a concrete one.

        Args:
            mode: the configured layer norm mode.

        Returns:
            `mode` itself unless it is `LayerNormMode.auto`; in that case `ppu` when the
            optimization target is PLUTO and `nn_core` for every other target.
        """
        if mode == LayerNormMode.auto:
            on_pluto = self.optimization_target == OptimizationTarget.PLUTO
            return LayerNormMode.ppu if on_pluto else LayerNormMode.nn_core
        return mode

    def _setup(self):
        """
        Load the algorithm configuration into instance state, reset the bookkeeping
        containers used during the run and validate the selected mode.
        """
        algo_cfg = self.get_algo_config()

        # Resolved mode and feature switches.
        self._mode = self._get_default_mode(algo_cfg.mode)
        self._equalization = algo_cfg.equalization == ThreeWayPolicy.enabled
        self._square_12_bit = algo_cfg.square_12_bit
        # NOTE(review): this uses the raw `equalization` policy value rather than the boolean
        # `self._equalization` computed above. If the policy is a plain Enum member it is always
        # truthy, making this equal to `self._square_12_bit` - confirm this is intended.
        self._optimize_ew_mult = algo_cfg.equalization and self._square_12_bit
        self._token_equalization = algo_cfg.token_equalization == FeaturePolicy.enabled
        self._add_buffer_layer = algo_cfg.add_buffer_layer == FeaturePolicy.enabled

        # Bookkeeping filled while the layer norms are decomposed.
        self._equalization_info_by_layer = {}
        self._online_token_equalization_info_by_layer = {}
        self._precision_split_layers = set()
        self._summed_squares = []

        self._verify()
        if not DEBUG:
            return

        debug_lines = (
            f"self.optimization_target {self.optimization_target}",
            f"self._mode {self._mode}",
            f"self._equalization  {self._equalization}",
            f"self._square_12_bit {self._square_12_bit}",
            f"self._optimize_ew_mult {self._optimize_ew_mult}",
            f"self._token_equalization {self._token_equalization}",
            f"self._add_buffer_layer {self._add_buffer_layer}",
        )
        for line in debug_lines:
            self._logger.info(line)

    def _verify(self):
        """
        Validate that the selected mode is compatible with the target and the layers.

        Raises:
            AccelerasImplementationError: when PPU mode is requested on SAGE, or when PPU mode
                is requested while some normalization layer is not in a16_w16 precision.
        """
        ppu_requested = self._mode == LayerNormMode.ppu
        if self.optimization_target == OptimizationTarget.SAGE and ppu_requested:
            raise AccelerasImplementationError("SAGE does not support PPU mode")

        # Evaluated for every layer regardless of the mode, exactly one flag per layer.
        inputs_are_16bit = [self._is_16_bit_inp(norm) for norm in self.norm_layers]
        if ppu_requested and not np.all(inputs_are_16bit):
            raise AccelerasImplementationError("PPU mode must have all layers as 16 bit input")

    def should_skip_algo(self):
        return len(self.norm_layers) == 0

    def get_algo_config(self):
        return self._model_config.layer_norm_decomposition

    def _run_int(self):
        """
        Dispatch to the scheme matching the resolved mode and the optimization target.

        The core decomposition runs for the `nn_core` mode and always for SAGE; otherwise the
        MERCURY or PLUTO PPU scheme runs. Any other target in a non-core mode is a no-op.
        """
        target = self.optimization_target
        if self._mode == LayerNormMode.nn_core or target == OptimizationTarget.SAGE:
            # Decomposition
            self._run_int_core()
            return

        # If force_hw was forced to be mercury or pluto we need to handle these cases -
        # the rest of the cases are handled already.
        if target == OptimizationTarget.MERCURY:
            # PPU scheme mercury
            self._run_int_mercury()
        elif target == OptimizationTarget.PLUTO:
            # PPU scheme pluto
            self._run_int_pluto()

    def _run_int_mercury(self):
        """Apply `_build_mercury` to every supported normalization layer of the model."""
        for norm in self.norm_layers:
            self._build_mercury(norm)

    def _run_int_pluto(self):
        """
        Apply `_build_pluto` to every supported normalization layer, then run the
        mixed precision creation algorithm on the model.
        """
        for norm in self.norm_layers:
            self._build_pluto(norm)

        CreateMixedPrecision(
            model=self._model,
            model_config=self._model_config,
            logger_level=self._logger_level_other,
            logger=self._logger,
        ).run()

    def _run_int_core(self):
        """
        Decompose every supported normalization layer into a flow of layers and splice it
        into the model, then run the follow-up precision / statistics / equalization stages.
        """

        def apply_mixed_precision():
            # Same sub-algorithm invocation is needed at two points of the pipeline.
            CreateMixedPrecision(
                model=self._model,
                model_config=self._model_config,
                logger_level=self._logger_level_other,
                logger=self._logger,
            ).run()

        algo_cfg = self.get_algo_config()
        for norm in self.norm_layers:
            # Step 1: decompose the layer norms (Layer norm, group norm, and RMS norm)
            flow = self.decompose_single_layer_norm(norm)

            # Step 2: Add equalization to the flow if needed
            if not algo_cfg.eq_consumer:
                self._remove_equalization_on_input_layer_norm(norm)
            if self._equalization:
                self.add_equalization_to_flow(flow, norm)
            if self._token_equalization:
                self.add_online_token_equalization_to_flow(flow, norm)
                if self._add_buffer_layer:
                    self._add_shortcut_input_buffer(flow, norm)

            # Step 3: Add the flow to the model
            self.add_flow_to_model(flow, norm)

        # Step 4: Add precision to layers
        apply_mixed_precision()

        # Step 5: Collect stats again after the new layers have been added
        if not self._run_statistics:
            return

        StatsCollector(
            self._model,
            self._model_config,
            self._logger_level_other,
            self._unbatched_dataset,
            layers_to_handle=self._layers_to_collect,
            logger=self._logger,
        ).run()

        for square_layer, mean_layer in self._summed_squares:
            self._calibrate_summed_squares(square_layer, mean_layer)
        if self._summed_squares:
            # Precision is created again once the summed squares were calibrated.
            apply_mixed_precision()

        for eq_info in self._equalization_info_by_layer.values():
            self.equalize_layer(eq_info)
        for split_name in self._precision_split_layers:
            self._split_layer(self._model, split_name)
        for token_info in self._online_token_equalization_info_by_layer.values():
            self.mask_online_token_equalization(token_info)

    @property
    def _layers_to_collect(self):
        layers_to_collect = set()
        for lname, equiv_set in self._equalization_info_by_layer.items():
            layers_to_collect.add(equiv_set.source)
            layers_to_collect.add(equiv_set.consumer_square)
            layers_to_collect.add(equiv_set.consumer_out)
        for lname in self._precision_split_layers:
            layers_to_collect.add(lname)
        for lname, equiv_set in self._online_token_equalization_info_by_layer.items():
            layers_to_collect.add(equiv_set.exp_decompose)
        for square_layer, mean_layer in self._summed_squares:
            layers_to_collect.add(square_layer.full_name)
        return layers_to_collect

    @property
    def _run_statistics(self):
        return (
            self._equalization_info_by_layer
            or self._precision_split_layers
            or self._summed_squares
            or self._online_token_equalization_info_by_layer
        )

    @property
    def norm_layers(self):
        res = []
        for layer in self._model.flow.toposort():
            acceleras_layer = self._model.layers[layer]
            if type(acceleras_layer) in SUPPORTED_LAYERS:
                res.append(acceleras_layer)
        return res

    def finalize_global_cfg(self, algo_config):
        """
        Resolve the `auto` / `allowed` values of the algorithm configuration in place.

        Args:
            algo_config: the layer norm decomposition configuration; its
                `bit_decomposition_mode` and `equalization` fields may be overwritten.

        Raises:
            NotImplementedError: when equalization is explicitly enabled together with
                uniform precision decomposition.
        """
        if algo_config.bit_decomposition_mode == LayerNormDecompositionMode.auto:
            # TODO : Add trainable to model state.
            # TODO: https://hailotech.atlassian.net/browse/SDK-59573
            finetune_on = self._model_config.finetune.policy == FinetunePolicy.enabled
            algo_config.bit_decomposition_mode = (
                LayerNormDecompositionMode.split_precision
                if finetune_on
                else LayerNormDecompositionMode.uniform_precision
            )

        is_uniform = algo_config.bit_decomposition_mode == LayerNormDecompositionMode.uniform_precision
        requested_equalization = algo_config.equalization
        if requested_equalization == ThreeWayPolicy.enabled and is_uniform:
            raise NotImplementedError("Equalization is not supported with uniform precision")
        if requested_equalization == ThreeWayPolicy.allowed:
            # "allowed" turns into a definite decision: equalize only for split precision.
            algo_config.equalization = ThreeWayPolicy.disabled if is_uniform else ThreeWayPolicy.enabled

    def _get_valid_layer_cfg(self, lname, cfg):
        return cfg

    @staticmethod
    def _is_16_bit_inp(layer_norm: HailoLayerNormalization):
        """Return whether the layer's reduced precision mode equals `PrecisionMode.a16_w16`."""
        return layer_norm.get_precision_mode().reduce() == PrecisionMode.a16_w16

    def add_flow_to_model(self, layer_norm_flow: nx.DiGraph, layer_norm):
        """
        Replace a normalization layer in the model by its decomposed flow.

        Args:
            layer_norm_flow (nx.DiGraph): decomposed flow. Nodes carry a "layer" attribute;
                nodes whose layer is None stand for layers that already exist in the model.
            layer_norm: the normalization layer being replaced.

        Raises:
            ValueError: when a new layer's name already exists in the model or its flow.
            RuntimeError: when the replaced layer appears in the output layer order and the
                flow's final edges do not originate from exactly one node.
        """
        replaced_name = layer_norm.full_name

        # Remove old layer norm from model & model flow
        del self._model.layers[replaced_name]
        self._model.flow.remove_node(replaced_name)
        self._model_config.remove_layer_from_all_configs(replaced_name)

        # Add nodes to the model & model flow
        for lname in nx.topological_sort(layer_norm_flow):
            layer = layer_norm_flow.nodes[lname]["layer"]
            if layer is None:
                continue  # boundary node, already part of the model
            if layer.full_name in self._model.layers or layer.full_name in self._model.flow.nodes:
                raise ValueError(f"{layer.full_name} already exists in model, Layer norm decomposition failed")
            self._model._unlock_model()
            try:
                self._model.layers[layer.full_name] = layer
            finally:
                # Never leave the model unlocked, even if the registration fails.
                self._model._lock_model()
            self._model.flow.add_node(layer.full_name)

        # Add edges to the model flow
        edges_graph = nx.line_graph(layer_norm_flow)
        for u, v in nx.topological_sort(edges_graph):  # topological sort for edges
            edge_attr = layer_norm_flow.edges[(u, v)]
            self._model.flow.add_edge(u, v, **edge_attr)

        # Fix output layer order if applies
        if replaced_name in self._model.flow.output_layer_order:
            # A line-graph node with out-degree 0 is an edge entering a sink of the flow. The
            # same source node shows up once per sink edge, so de-duplicate (keeping order)
            # before checking uniqueness; otherwise a layer norm with several successors
            # would be rejected although it has a single final node.
            out_nodes = list(dict.fromkeys(u for (u, _), d in edges_graph.out_degree if d == 0))
            if len(out_nodes) != 1:
                raise RuntimeError(f"Unexpected output nodes when decomposing layer norm {out_nodes}")
            idx = self._model.flow.output_layer_order.index(replaced_name)
            self._model.flow.output_layer_order[idx] = out_nodes[0]

    def _add_shortcut_input_buffer(self, layer_norm_flow, layer_norm):
        """
        Insert a linear 16 bit standalone activation ("shortcut input buffer") right after the
        RMS root node of the flow, so every edge leaving the root leaves the buffer instead.

        Nothing is added when the root node is not a layer of the model, when it is a
        HailoInputLayer, or when it has exactly one successor in the model flow.

        Args:
            layer_norm_flow: the decomposed layer norm flow, modified in place.
            layer_norm: the normalization layer being decomposed.
        """
        entry_node, _ = self._find_rms_root_layer(layer_norm_flow, layer_norm)

        # Rule designed for the LLM block, only for RMS norm.
        if entry_node not in self._model.layers:
            # The root was created by the decomposition itself - no buffer.
            return
        entry_layer = self._model.layers[entry_node]
        if isinstance(entry_layer, HailoInputLayer):
            # The root is a model input layer - no buffer.
            return
        if len(list(self._model.flow.successors(entry_node))) == 1:
            # The root feeds a single consumer in the model, so there is no shortcut to buffer.
            return

        self._logger.debug(f"Adding shortcut input buffer to {entry_node}")
        buffer_name = self._get_new_layer_name("shortcut_input_buffer", layer_norm)
        buffer_layer = self.create_standalone_activation(
            buffer_name,
            entry_layer.output_shape,
            activation="linear",
            precision_mode="a16_w16",
            bias_mode="double_scale_initialization",
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        layer_norm_flow.add_node(buffer_name, layer=buffer_layer)

        # Move all edges leaving the root onto the buffer, then connect root -> buffer.
        edges_from_root = list(layer_norm_flow.out_edges(entry_node))
        self._insert_block(layer_norm_flow, edges_from_root, buffer_name)
        layer_norm_flow.add_edge(entry_node, buffer_name)

    def add_equalization_to_flow(self, layer_norm_flow: nx.DiGraph, layer_norm: HailoLayerNormalization):
        """
        Add the equalization source / consumer layers to the flow and record their names.

        Args:
            layer_norm_flow (nx.DiGraph): the decomposed layer norm flow, modified in place.
            layer_norm (HailoLayerNormalization): the normalization layer being decomposed.
        """
        precision_mode = layer_norm.get_precision_mode().reduce()
        is_16bit = precision_mode == PrecisionMode.a16_w16
        bias_mode = BiasMode.single_scale_decomposition if is_16bit else BiasMode.double_scale_initialization

        # The equalization source may have been added earlier (maybe it should be only in
        # the decomposed case); create it only when it is missing.
        source_name = self._get_new_layer_name("equalization_source", layer_norm)
        if source_name not in layer_norm_flow.nodes:
            self.add_equalization_source(layer_norm_flow, layer_norm, precision_mode, bias_mode)

        # The consumer is inserted before the squared consumer is looked up.
        consumer_out_name = self.add_equalization_consumer(layer_norm_flow, layer_norm, precision_mode, bias_mode)
        consumer_square_name = self.find_equalization_squared_consumer(layer_norm_flow, source_name)

        self._equalization_info_by_layer[layer_norm.full_name] = EquivClassNorm(
            source_name, consumer_square_name, consumer_out_name
        )

    def find_equalization_squared_consumer(self, layer_norm_flow, eq_src_name):
        """
        Return the name of the first HailoConv (in topological order) among the descendants
        of the equalization source, or None when there is none.

        Args:
            layer_norm_flow: the decomposed layer norm flow.
            eq_src_name: node name of the equalization source layer.
        """
        downstream = layer_norm_flow.subgraph(nx.descendants(layer_norm_flow, eq_src_name))
        conv_names = (
            downstream.nodes[name]["layer"].full_name
            for name in nx.topological_sort(downstream)
            if isinstance(downstream.nodes[name]["layer"], HailoConv)
        )
        return next(conv_names, None)

    def add_equalization_consumer(self, layer_norm_flow, layer_norm, precision_mode, bias_mode):
        """
        Insert a normalization layer between the last layer of the flow and the flow's sinks.

        Args:
            layer_norm_flow: the decomposed layer norm flow, modified in place.
            layer_norm: the normalization layer being decomposed.
            precision_mode: precision mode of the new layer.
            bias_mode: bias mode of the new layer.

        Returns:
            Name of the inserted "equalization_consumer_out" layer.

        Raises:
            RuntimeError: when the sinks of the flow are not all fed by one single node.
        """
        # Sinks of the flow (nodes nothing leaves from) and the nodes feeding them.
        sinks = [name for name, out_deg in layer_norm_flow.out_degree if out_deg == 0]
        feeders = {pred for sink in sinks for pred in layer_norm_flow.predecessors(sink)}
        if len(feeders) != 1:
            raise RuntimeError("Layer norm decomposition flow should have exactly one output node")
        (feeder_name,) = feeders
        feeder_layer = layer_norm_flow.nodes[feeder_name]["layer"]

        # Create the consumer and hook it after the feeder.
        eq_consumer_name = self._get_new_layer_name("equalization_consumer_out", layer_norm)
        layer_norm_flow.add_node(
            eq_consumer_name,
            layer=self.create_normalization_layer(
                eq_consumer_name,
                self._get_layer_output_shape(feeder_layer),
                precision_mode=precision_mode,
                bias_mode=bias_mode,
                original_names=layer_norm.hn_element.get("original_names", []),
            ),
        )
        layer_norm_flow.add_edge(feeder_name, eq_consumer_name)

        # Re-route every feeder -> sink edge through the consumer, keeping its attributes.
        for sink in sinks:
            edge_attr = layer_norm_flow.edges[(feeder_name, sink)]
            layer_norm_flow.remove_edge(feeder_name, sink)
            layer_norm_flow.add_edge(eq_consumer_name, sink, **edge_attr)
        return eq_consumer_name

    def _find_rms_root_layer(self, layer_norm_flow: nx.DiGraph, layer_norm: HailoLayerNormalization):
        """
        Locate the node the RMS part of the normalization starts from.

        The starting point is the single entry node of the flow, or its first successor when
        the `eq_consumer` option is set. For an RMS norm this node is the root; otherwise the
        root is the elementwise subtraction layer that follows it.

        Args:
            layer_norm_flow (nx.DiGraph): The layer normalization flow graph.
            layer_norm (HailoLayerNormalization): The Hailo layer normalization object.

        Returns:
            tuple: (name of the root node, output shape of the root layer).

        Raises:
            RuntimeError: when the flow does not have exactly one entry node, or when no
                elementwise subtraction layer follows the starting node of a non-RMS norm.
        """
        entry_points = [name for name, in_deg in layer_norm_flow.in_degree if in_deg == 0]
        if len(entry_points) != 1:
            raise RuntimeError("Layer norm decomposition flow should have exactly one input node")
        (start_node,) = entry_points

        if self.get_algo_config().eq_consumer:
            # Skip over the layer placed directly after the entry node.
            start_node = list(layer_norm_flow.successors(start_node))[0]

        if layer_norm.rms_norm:
            rms_root = start_node
        else:
            # find ew_sub layer after the starting node
            rms_root = next(
                (
                    name
                    for name in layer_norm_flow.successors(start_node)
                    if isinstance(layer_norm_flow.nodes[name]["layer"], HailoElementwiseSub)
                ),
                None,
            )
            if rms_root is None:
                raise RuntimeError("Could not find ew_sub layer after input_node")

        # Nodes without a layer object refer to layers that live in the model.
        root_layer = layer_norm_flow.nodes[rms_root]["layer"]
        shape_source = self._model.layers[rms_root] if root_layer is None else root_layer
        return rms_root, self._get_layer_output_shape(shape_source)

    def _insert_block(self, layer_norm_flow: nx.DiGraph, edges, new_output):
        """
        Re-route a group of edges so that they leave `new_output` instead of their common
        source node, keeping the attributes of every edge.
        Connecting the original source node to the start of the inserted block is left to
        the caller.

        Args:
            layer_norm_flow (nx.DiGraph): The layer normalization flow graph.
            edges: The edges to be replaced; all of them must leave the same node.
            new_output: name of the node the edges will leave from after the call.

        Raises:
            AccelerasImplementationError: when the edges leave different nodes, or when the
                target layers carry different external input lossy elements.
        """
        if len(edges) == 0:  # nothing to do
            return
        origins = {origin for origin, _ in edges}
        if len(origins) != 1:
            raise AccelerasImplementationError("All edges must exit from the same node")

        shared_lossy = None
        for origin, target in edges:
            edge_attr = layer_norm_flow.edges[(origin, target)]
            layer_norm_flow.remove_edge(origin, target)
            layer_norm_flow.add_edge(new_output, target, **edge_attr)

            # Keep the 12bit squared output behavior if external lossy element exists
            target_lossy = layer_norm_flow.nodes[target]["layer"].input_lossy_element_external
            if target_lossy is None:
                continue
            if shared_lossy is None:
                shared_lossy = target_lossy
            elif shared_lossy != target_lossy:
                raise AccelerasImplementationError("Custom lossy element mismatch")

        # The external lossy element moves from the original source to the new output node.
        layer_norm_flow.nodes[new_output]["layer"].output_lossy_element_external = shared_lossy
        origin_layer = layer_norm_flow.nodes[list(edges)[0][0]]["layer"]
        if origin_layer is not None:
            origin_layer.output_lossy_element_external = None

    def add_equalization_source(self, layer_norm_flow, layer_norm, precision_mode, bias_mode):
        """
        Insert the "equalization_source" normalization layer right after the RMS root node.

        Args:
            layer_norm_flow: the decomposed layer norm flow, modified in place.
            layer_norm: the normalization layer being decomposed.
            precision_mode: precision mode of the new layer.
            bias_mode: bias mode of the new layer.

        Returns:
            Name of the inserted layer.
        """
        root_name, root_shape = self._find_rms_root_layer(layer_norm_flow, layer_norm)

        eq_src_name = self._get_new_layer_name("equalization_source", layer_norm)
        layer_norm_flow.add_node(
            eq_src_name,
            layer=self.create_normalization_layer(
                eq_src_name,
                root_shape,
                precision_mode=precision_mode,
                bias_mode=bias_mode,
                original_names=layer_norm.hn_element.get("original_names", []),
            ),
        )

        # Everything that used to leave the root now leaves the source; then root -> source.
        edges_from_root = list(layer_norm_flow.out_edges(root_name))
        self._insert_block(layer_norm_flow, edges_from_root, eq_src_name)
        layer_norm_flow.add_edge(root_name, eq_src_name)
        return eq_src_name

    def add_online_token_equalization_to_flow(self, layer_norm_flow: nx.DiGraph, layer_norm: HailoLayerNormalization):
        """
        Add the online token equalization block to the flow of a 16 bit layer norm.

        The block is built after the RMS root node as:
        root -> exp_decompose -> reduce_max1 [-> reduce_max2 when groups > 1] -> shift (input 1),
        with the root also feeding shift directly (input 0). Every edge that used to leave the
        root leaves the shift layer afterwards. The names of the exp_decompose and shift layers
        are recorded in `self._online_token_equalization_info_by_layer`.

        Nothing is done when the layer's reduced precision mode is not a16_w16.

        Args:
            layer_norm_flow (nx.DiGraph): the decomposed layer norm flow, modified in place.
            layer_norm (HailoLayerNormalization): the normalization layer being decomposed.
        """
        precision_mode = layer_norm.get_precision_mode().reduce()
        if precision_mode != PrecisionMode.a16_w16:  # Only relevant for 16 bit
            return

        # Node the block is attached to, and the shape of its output.
        pred, input_shape = self._find_rms_root_layer(layer_norm_flow, layer_norm)

        groups = layer_norm.groups
        exp_decompose_name = self._get_new_layer_name("exp_decompose", layer_norm)
        reduce_max_name = self._get_new_layer_name("reduce_max1", layer_norm)

        exp_decompose = self.create_standalone_activation(
            exp_decompose_name,
            input_shape,
            activation="exp_decompose",
            precision_mode="a16_w16",
            bias_mode="double_scale_initialization",
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        layer_norm_flow.add_node(exp_decompose_name, layer=exp_decompose)

        # First max reduction, over axis 3 and split by the layer norm groups.
        reduce_max = self.create_reduce_max_layer(
            reduce_max_name,
            input_shape,
            reduce_axes=[3],
            groups=groups,
            precision_mode="a16_w16",
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        layer_norm_flow.add_node(reduce_max_name, layer=reduce_max)
        layer_norm_flow.add_edge(exp_decompose_name, reduce_max_name)

        if groups > 1:
            # Grouped case: a second max reduction over axes 1 and 2 follows the first one.
            reduce_max_spatial_name = self._get_new_layer_name("reduce_max2", layer_norm)
            reduce_max_spatial = self.create_reduce_max_layer(
                reduce_max_spatial_name,
                self._get_layer_output_shape(reduce_max),
                reduce_axes=[1, 2],
                groups=1,
                precision_mode="a16_w16",
                original_names=layer_norm.hn_element.get("original_names", []),
            )
            layer_norm_flow.add_node(reduce_max_spatial_name, layer=reduce_max_spatial)
            layer_norm_flow.add_edge(reduce_max_name, reduce_max_spatial_name)
            last_node_name = reduce_max_spatial_name
            last_node = reduce_max_spatial
        else:
            last_node_name = reduce_max_name
            last_node = reduce_max

        # Elementwise add with the "shift" activation: input 0 is the root output,
        # input 1 is the output of the last max reduction.
        shift_name = self._get_new_layer_name("shift", layer_norm)
        shift = self.create_ew_add_layer(
            shift_name,
            [input_shape, self._get_layer_output_shape(last_node)],
            activation="shift",
            precision_mode="a16_w16",
            bias_mode="double_scale_initialization",
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        layer_norm_flow.add_node(shift_name, layer=shift)
        layer_norm_flow.add_edge(last_node_name, shift_name, input_index=1)

        # Move the existing consumers of the root onto the shift layer first, and only then
        # connect the root to the new block (otherwise the new edges would be moved as well).
        self._insert_block(layer_norm_flow, list(layer_norm_flow.out_edges(pred)), shift_name)
        layer_norm_flow.add_edge(pred, exp_decompose_name)
        layer_norm_flow.add_edge(pred, shift_name, input_index=0)

        self._online_token_equalization_info_by_layer[layer_norm.full_name] = OnlineTokenEqualizationBlock(
            exp_decompose_name, shift_name
        )

    def decompose_single_layer_norm(self, layer_norm: HailoLayerNormalization):
        """
        Decompose a single layer norm according to its precision mode and other parameters.

        Builds a standalone graph that starts at the first predecessor of the layer norm in
        the model, contains the newly created layers, and ends at the successors of the layer
        norm. Predecessor / successor nodes are stored with `layer=None` since they already
        exist in the model. The model itself is not modified here.

        Args:
            layer_norm (HailoLayerNormalization): the normalization layer to decompose.

        Returns:
            nx.DiGraph: the decomposed flow; each node has a "layer" attribute.

        Raises:
            NotImplementedError: when the layer is both a group norm and an RMS norm, or when
                the predecessor layer has more than one output.
        """
        # this function will be called when we enable equalization of 8 bit
        cfg = self.get_algo_config()
        precision_mode = layer_norm.get_precision_mode().reduce()
        if precision_mode == PrecisionMode.a16_w16:
            bias_mode = BiasMode.single_scale_decomposition
        else:
            bias_mode = BiasMode.double_scale_initialization

        # The configured bit decomposition mode applies to 16 bit layers only;
        # any other precision always uses the uniform precision decomposition.
        bit_decomposition_mode = (
            cfg.bit_decomposition_mode
            if precision_mode == PrecisionMode.a16_w16
            else LayerNormDecompositionMode.uniform_precision
        )

        is_rms_norm = layer_norm.rms_norm
        is_group_norm = layer_norm.groups > 1
        is_layer_norm = not is_rms_norm and not is_group_norm
        if is_group_norm and is_rms_norm:
            raise NotImplementedError("Group norm and RMS norm are not supported together")
        # Only the first (sorted) predecessor of the layer norm is used as the flow entry.
        first_node = self._model.layers[self._model.flow.predecessors_sorted(layer_norm.full_name)[0]]

        layer_norm_flow = nx.DiGraph()
        layer_norm_flow.add_node(first_node.full_name, layer=None)
        # Note: the output index of the first node is not preserved.
        if first_node.num_outputs != 1:
            raise NotImplementedError("Layer norm decomposition doesn't work directly after multi output layer.")
        curr_node = first_node
        if cfg.eq_consumer:
            # Optional normalization layer placed directly after the entry node.
            eq_consumer_name = self._get_new_layer_name("pre_ln_equalization_consumer", layer_norm)
            eq_consumer_node = self.create_normalization_layer(
                eq_consumer_name,
                self._get_layer_output_shape(first_node),
                precision_mode=precision_mode,
                bias_mode=bias_mode,
            )
            # add_edge creates the node implicitly; add_node then attaches the layer attribute.
            layer_norm_flow.add_edge(first_node.full_name, eq_consumer_name)
            layer_norm_flow.add_node(eq_consumer_name, layer=eq_consumer_node)
            curr_node = eq_consumer_node

        # region x-mu
        # Step 1: Add x-mu (if needed)
        if is_layer_norm or is_group_norm:
            # Type 1: Layer norm (x-mu)
            # Type 2: Group norm (x-mu with spatial mean and resize)
            # grouped kwarg determines the type
            curr_node = self._subtract_average_flow(
                curr_node,
                layer_norm_flow,
                layer_norm,
                precision_mode=precision_mode,
                bias_mode=bias_mode,
            )
        elif is_rms_norm:
            # Type 3: RMS norm (no x-mu)
            curr_node = curr_node
        else:
            # Defensive branch: the flags above leave no other combination.
            raise AccelerasImplementationError("Unexpected layer norm type")
        # endregion

        # Output of step 1; it is the numerator multiplied by the inverse variance below.
        numerator_node = curr_node

        # Step 2: inverse variance flow followed by the multiplication with the numerator.
        if bit_decomposition_mode == LayerNormDecompositionMode.split_precision:
            # Type 1: decomposed precision
            curr_node, low, high = self._inverse_variance_flow_decomposed(curr_node, layer_norm_flow, layer_norm)
            # In this case - low, high are the numerator node
            curr_node = self._decomposed_mult_flow(curr_node, low, high, layer_norm_flow, layer_norm)
        elif precision_mode == PrecisionMode.a8_w8:
            # Type 2: normal 8bit
            curr_node = self._inverse_variance_flow_8bit(curr_node, layer_norm_flow, layer_norm)
            curr_node = self._ew_mult_flow(
                curr_node, numerator_node, layer_norm_flow, layer_norm, precision_mode, bias_mode
            )
        else:
            # Type 3: normal 16bit
            curr_node = self._inverse_variance_flow_16bit(curr_node, layer_norm_flow, layer_norm)
            curr_node = self._ew_mult_flow(
                curr_node,
                numerator_node,
                layer_norm_flow,
                layer_norm,
                precision_mode,
                BiasMode.double_scale_initialization,
            )

        # Step 3: connect the last new layer to every successor of the original layer norm,
        # reusing the attributes of the original layer norm -> successor edges.
        last_node_names = self._model.flow.successors_sorted(layer_norm.full_name)
        for node_name in last_node_names:
            node = self._model.layers[node_name]
            layer_norm_flow.add_node(node.full_name, layer=None)
            edge_attr = self._model.flow.edges[(layer_norm.full_name, node.full_name)]
            layer_norm_flow.add_edge(curr_node.full_name, node.full_name, **edge_attr)
        return layer_norm_flow

    def _inverse_variance_flow_8bit(
        self,
        input_node: BaseHailoLayer,
        layer_norm_flow: nx.DiGraph,
        layer_norm: HailoLayerNormalization,
    ):
        """
        Add the 8 bit inverse variance part of the decomposition to the flow.

        Chains the square flow and the fused mean / inverse-sqrt flow after `input_node`,
        both in a8_w8 precision with double scale initialization bias mode.

        Args:
            input_node (BaseHailoLayer): layer the inverse variance flow is attached to.
            layer_norm_flow (nx.DiGraph): the decomposed layer norm flow, passed on to the
                sub-flows.
            layer_norm (HailoLayerNormalization): the normalization layer being decomposed.

        Returns:
            The value returned by `_fuse_mean_and_inv_sqrt_flow` (used by the caller as the
            current last node of the flow).
        """
        prec_mode = PrecisionMode.a8_w8
        bias_mode = BiasMode.double_scale_initialization

        squared_node = self._square_flow(
            input_node,
            layer_norm_flow,
            layer_norm,
            precision_mode=prec_mode,
            bias_mode=bias_mode,
        )
        return self._fuse_mean_and_inv_sqrt_flow(
            squared_node,
            layer_norm_flow,
            layer_norm,
            prec_mode,
            bias_mode,
        )

    def _inverse_variance_flow_16bit(
        self,
        input_node: BaseHailoLayer,
        layer_norm_flow: nx.DiGraph,
        layer_norm: HailoLayerNormalization,
    ):
        """Append the 16-bit inverse-variance sub-flow after ``input_node``.

        Up to three sub-steps are added to ``layer_norm_flow``:

        1. A square layer (``a16_w16``) created with ``reduce_sum_groups`` equal to the channel count.
        2. A weighted conv-mean layer named ``conv_var_inv``.
        3. Only when the reduction is over axes (1, 2, 3) or ``layer_norm.groups > 1``: a spatial
           mean + resize that carries the epsilon bias and the ``inv_sqrt`` activation. Otherwise
           the conv-mean layer of step 2 carries them.

        The ``(square, conv_mean)`` pair is appended to ``self._summed_squares``.

        Args:
            input_node: Layer whose output is squared.
            layer_norm_flow: Graph that receives the new nodes and edges.
            layer_norm: The layer-norm layer being decomposed.

        Returns:
            The last layer of the sub-flow.
        """
        prec_mode = PrecisionMode.a16_w16
        bias_mode = BiasMode.single_scale_decomposition

        channels = self._get_layer_output_shape(input_node)[-1]

        # RMS norm / layer norm: the mean is split over two layers - the square applies a partial
        # sum and the conv that follows applies the rest.
        # Group norm: the square sums part of the features, the conv handles the remaining features
        # of each group, and a spatial mean is applied afterwards.
        epsilon = layer_norm.export_weights()["epsilon"]

        # From here on "group" may mean layer_norm.groups or the per-input groups of the nudging logic.
        has_spatial_stage = tuple(layer_norm.reduce_axes) == (1, 2, 3) or layer_norm.groups > 1
        if has_spatial_stage:
            conv_activation, conv_bias = "linear", 0
        else:
            conv_activation, conv_bias = "inv_sqrt", epsilon
        groups = channels
        group_size = channels // groups  # number of channels summed together

        # Substep 1: square in 16 bit with a partial sum; the sum is limited based on the stats and
        # the accumulator capabilities.
        square_input_shape = self._get_layer_output_shape(input_node)
        square_layer = self.create_square_layer(
            self._get_new_layer_name("square1", layer_norm),
            square_input_shape,
            precision_mode=prec_mode,
            bias_mode=BiasMode.double_scale_initialization,
            reduce_sum_groups=groups,  # will be modified after stats
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        layer_norm_flow.add_node(square_layer.full_name, layer=square_layer)
        layer_norm_flow.add_edge(input_node.full_name, square_layer.full_name)

        # Substep 2: remaining sum and division through a weighted conv layer.
        mean_factor = 1 / group_size
        conv_mean = self.create_conv_mean_layer(
            self._get_new_layer_name("conv_var_inv", layer_norm),
            self._get_layer_output_shape(square_layer),
            factors=(mean_factor,),
            norm_groups=layer_norm.groups,
            bias=conv_bias,
            activation=conv_activation,
            precision_mode=prec_mode,
            bias_mode=bias_mode,
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        conv_mean.conv_op.force_rounded_shift_delta = True  # TODO - we may not need this anymore
        layer_norm_flow.add_node(conv_mean.full_name, layer=conv_mean)
        layer_norm_flow.add_edge(square_layer.full_name, conv_mean.full_name)
        tail = conv_mean

        # Substep 3 (optional): spatial mean & resize.
        if has_spatial_stage:
            tail = self._spatial_mean_flow(
                tail,
                layer_norm_flow,
                layer_norm,
                prec_mode,
                bias_mode,
                suffix="spatial_variance",
                bias=epsilon,
                activation="inv_sqrt",
            )
        self._summed_squares.append((square_layer, conv_mean))
        return tail

    def _inverse_variance_flow_decomposed(
        self,
        input_node: BaseHailoLayer,
        layer_norm_flow: nx.DiGraph,
        layer_norm: HailoLayerNormalization,
    ):
        """Append the decomposed inverse-variance sub-flow after ``input_node``.

        The input goes through ``_split_precision_flow`` (yielding ``low`` and ``high``), the
        two parts are squared by ``_decomposed_square_flow``, and the concatenated result is
        reduced by ``_fuse_mean_and_inv_sqrt_flow`` in ``a16_w16`` with
        ``single_scale_decomposition`` bias, using the factors returned by the square stage.

        Args:
            input_node: Layer whose output is split and squared.
            layer_norm_flow: Graph that receives the new nodes and edges.
            layer_norm: The layer-norm layer being decomposed (forwarded to helpers that read
                its weights and HN element, so it is the layer itself and not a
                decomposition-mode enum value).

        Returns:
            Tuple ``(last_node, low, high)``: the last layer of the sub-flow and the two layers
            produced by the precision split.
        """
        low, high = self._split_precision_flow(input_node, layer_norm_flow, layer_norm)
        squares, factors = self._decomposed_square_flow(low, high, layer_norm_flow, layer_norm)

        inv_std = self._fuse_mean_and_inv_sqrt_flow(
            squares,
            layer_norm_flow,
            layer_norm,
            PrecisionMode.a16_w16,
            BiasMode.single_scale_decomposition,
            factors,
        )
        return inv_std, low, high

    # region subflows of layer norm decompositions
    def _decomposed_mult_flow(
        self,
        input_node1: BaseHailoLayer,
        low: BaseHailoLayer,
        high: BaseHailoLayer,
        layer_norm_flow: nx.DiGraph,
        layer_norm: HailoLayerNormalization,
    ):
        """Multiply ``input_node1`` by the ``low`` and ``high`` parts and add the two products.

        ``input_node1`` first goes through a precision-change layer
        (``PrecisionMode.a16_w16_a8``). Its output is multiplied element-wise with ``low`` and
        with ``high`` in two separate ``a8_w8`` layers, and the products are summed by an
        element-wise add layer (``PrecisionMode.a8_w8_a16``).

        Args:
            input_node1: Layer feeding input 0 of both multiplications (after the precision change).
            low: Layer feeding input 1 of the "low" multiplication.
            high: Layer feeding input 1 of the "high" multiplication.
            layer_norm_flow: Graph that receives the new nodes and edges.
            layer_norm: The layer-norm layer being decomposed.

        Returns:
            The element-wise add layer that closes the sub-flow.
        """
        # Only the "high" multiplication gets different mock kernel values when optimizing.
        low_kernel_values = [2, 2]
        high_kernel_values = [2, 32] if self._optimize_ew_mult else [2, 2]

        narrowed = self.create_precision_change_layer(
            self._get_new_layer_name("precision_change_pre_mult", layer_norm),
            self._get_layer_output_shape(input_node1),
            PrecisionMode.a16_w16_a8,
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        layer_norm_flow.add_node(narrowed.full_name, layer=narrowed)
        layer_norm_flow.add_edge(input_node1.full_name, narrowed.full_name)

        def _add_product(suffix, other, kernel_values):
            """Create ``narrowed * other`` and wire both inputs into the flow."""
            product = self.create_ew_mult_layer(
                self._get_new_layer_name(suffix, layer_norm),
                [
                    self._get_layer_output_shape(narrowed),
                    self._get_layer_output_shape(other),
                ],
                mock_kernel_values=kernel_values,
                precision_mode=PrecisionMode.a8_w8,
                bias_mode=BiasMode.double_scale_initialization,
                original_names=layer_norm.hn_element.get("original_names", []),
            )
            layer_norm_flow.add_node(product.full_name, layer=product)
            layer_norm_flow.add_edge(narrowed.full_name, product.full_name, input_index=0)
            layer_norm_flow.add_edge(other.full_name, product.full_name, input_index=1)
            return product

        product_low = _add_product("ew_mult_low_var", low, low_kernel_values)
        product_high = _add_product("ew_mult_high_var", high, high_kernel_values)

        total = self.create_ew_add_layer(
            self._get_new_layer_name("ew_add_out", layer_norm),
            [
                self._get_layer_output_shape(product_low),
                self._get_layer_output_shape(product_high),
            ],
            precision_mode=PrecisionMode.a8_w8_a16,
            bias_mode=BiasMode.double_scale_initialization,
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        layer_norm_flow.add_node(total.full_name, layer=total)
        for index, product in enumerate((product_low, product_high)):
            layer_norm_flow.add_edge(product.full_name, total.full_name, input_index=index)

        return total

    def _ew_mult_flow(
        self,
        input_node1: BaseHailoLayer,
        input_node2: BaseHailoLayer,
        layer_norm_flow: nx.DiGraph,
        layer_norm: HailoLayerNormalization,
        precision_mode,
        bias_mode,
    ):
        """Add an element-wise multiplication of two nodes to the flow.

        Args:
            input_node1: Layer wired to input 0 of the multiplication.
            input_node2: Layer wired to input 1 of the multiplication.
            layer_norm_flow: Graph that receives the new node and edges.
            layer_norm: The layer-norm layer being decomposed (used for naming and original names).
            precision_mode: Precision mode given to the new layer.
            bias_mode: Bias mode given to the new layer.

        Returns:
            The created element-wise multiplication layer.
        """
        operands = (input_node1, input_node2)
        operand_shapes = [self._get_layer_output_shape(operand) for operand in operands]
        product = self.create_ew_mult_layer(
            self._get_new_layer_name("ew_mult1", layer_norm),
            operand_shapes,
            precision_mode=precision_mode,
            bias_mode=bias_mode,
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        layer_norm_flow.add_node(product.full_name, layer=product)
        for index, operand in enumerate(operands):
            layer_norm_flow.add_edge(operand.full_name, product.full_name, input_index=index)
        return product

    def _fuse_mean_and_inv_sqrt_flow(
        self,
        input_node: BaseHailoLayer,
        layer_norm_flow: nx.DiGraph,
        layer_norm: HailoLayerNormalization,
        precision_mode,
        bias_mode,
        fusing_factors=(1,),
    ):
        """Append the mean + inverse-sqrt stage that follows the squared input.

        Only the "nudged" (grouped) variant is implemented. Two norm types are handled:

        * Type 1 - layer norm / RMS norm: the conv-mean layer carries the epsilon bias and the
          ``inv_sqrt`` activation.
        * Type 2 - group norm or a (1, 2, 3) reduction: the conv-mean layer is linear with zero
          bias, and a spatial mean + resize carrying epsilon and ``inv_sqrt`` is appended.

        Args:
            input_node: Layer producing the squared (optionally concatenated) input.
            layer_norm_flow: Graph that receives the new nodes and edges.
            layer_norm: The layer-norm layer being decomposed.
            precision_mode: Precision mode given to the created layers.
            bias_mode: Bias mode given to the created layers.
            fusing_factors: One weight per concatenated input; ``(1,)`` when nothing was concatenated.

        Returns:
            The last layer of the stage.

        Raises:
            ValueError: If the channel count is not divisible by ``len(fusing_factors)``, or if
                ``force_group_size_split`` is set and ``group_size_split`` does not divide the
                per-input channel count.
        """
        cfg = self.get_algo_config()
        input_shape = self._get_layer_output_shape(input_node)
        # fused_channels is the channel count after the concatenation (if any). Without a
        # concatenation it equals real_channels and fusing_factors should be (1,).
        fused_channels = input_shape[-1]
        real_channels = fused_channels // len(fusing_factors)
        if fused_channels % len(fusing_factors) != 0:
            raise ValueError("Number of channels must be divisible by the number of factors")

        # For RMS norm / layer norm the mean is split over two layers: the first applies a partial
        # sum over each group of channels and the second sums the groups and divides.
        epsilon = layer_norm.export_weights()["epsilon"]

        # From here on "group" may mean layer_norm.groups or the per-input groups of the nudging logic.
        is_group_norm = tuple(layer_norm.reduce_axes) == (1, 2, 3) or layer_norm.groups > 1
        if is_group_norm:
            groups = layer_norm.groups
            activation = "linear"
            ch_mean_bias = 0
        else:
            # Both layer-norm/RMS-norm variants finish inside the conv-mean layer. These two values
            # used to be assigned only in the heuristic branch, so force_group_size_split=True
            # failed with UnboundLocalError when the conv-mean layer was created.
            activation = "inv_sqrt"
            ch_mean_bias = epsilon
            if not cfg.force_group_size_split:
                # Rough heuristic that helps the performance of llama2_7b (4096 channels per input).
                # qwen has 2048 channels and uses 4 groups, i.e. a group size of 512; the ratio is
                # kept similar to qwen's.
                divisors = self._get_divisors(real_channels)  # default range of divisors is 1-12
                # pick the divisor whose group size is closest to group_size_split (512 by default)
                idx = np.argmin(np.abs((real_channels / divisors) - cfg.group_size_split))
                groups = int(divisors[idx])
            else:
                groups = real_channels // cfg.group_size_split
                if cfg.group_size_split * groups != real_channels:
                    # NOTE(review): `cfg.get_feature` is interpolated without being called while
                    # `cfg.get_command()` is - confirm whether it should be `cfg.get_feature()`.
                    raise ValueError(
                        f"Number of channels must be divisible by the number of groups. "
                        f"group size: {cfg.group_size_split}, channels: {real_channels}, layer: {layer_norm.full_name}. "
                        f"Please change the group size or don't force group size split. "
                        f"{cfg.get_command()}({cfg.get_feature}, force_group_size_split=False) or "
                        f"{cfg.get_command()}({cfg.get_feature}, force_group_size_split=True, group_size_split=<valid divisor>)"
                    )
        group_size = real_channels // groups  # group size will be the size of the summed channels

        if groups > 1 or len(fusing_factors) > 1:
            # Substep 1: partial sum ("nudged") with weighted conv.
            # With a single group and a single fusing factor there is nothing to sum separately.
            partial_sum_name = self._get_new_layer_name("normalization_nudge", layer_norm)
            partial_sum = self.create_conv_sum_layer(
                partial_sum_name,
                input_shape,
                group_size,
                precision_mode=precision_mode,
                bias_mode=bias_mode,
                original_names=layer_norm.hn_element.get("original_names", []),
            )
            layer_norm_flow.add_node(partial_sum.full_name, layer=partial_sum)
            layer_norm_flow.add_edge(input_node.full_name, partial_sum.full_name)
            curr_node = partial_sum
            mean_factor = 1 / group_size
        else:
            curr_node = input_node
            mean_factor = 1

        # Substep 2: Additional sum and division with weighted conv layer
        mean_name = self._get_new_layer_name("conv_var_inv", layer_norm)
        input_shape = self._get_layer_output_shape(curr_node)
        conv_mean = self.create_conv_mean_layer(
            mean_name,
            input_shape,
            factors=[mean_factor * f for f in fusing_factors],
            norm_groups=layer_norm.groups,
            bias=ch_mean_bias,
            activation=activation,
            precision_mode=precision_mode,
            bias_mode=bias_mode,
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        conv_mean.conv_op.force_rounded_shift_delta = True  # TODO - we may not need this anymore
        layer_norm_flow.add_node(conv_mean.full_name, layer=conv_mean)
        layer_norm_flow.add_edge(curr_node.full_name, conv_mean.full_name)
        curr_node = conv_mean

        # Substep 3: (Optional) Apply spatial mean & resize
        if is_group_norm:
            curr_node = self._spatial_mean_flow(
                curr_node,
                layer_norm_flow,
                layer_norm,
                precision_mode,
                bias_mode,
                suffix="spatial_variance",
                bias=epsilon,
                activation="inv_sqrt",
            )
        return curr_node

    def _decomposed_square_flow(
        self,
        low: BaseHailoLayer,
        high: BaseHailoLayer,
        layer_norm_flow: nx.DiGraph,
        layer_norm: HailoLayerNormalization,
    ):
        """Add a decomposed version of the square layer to the flow.

        ``low`` and ``high`` are the two parts of a single 16-bit input after the precision
        split. Three ``a8_w8`` products are built - ``low * low``, ``high * high`` and
        ``low * high`` - each followed by a precision change (``PrecisionMode.a8_w8_a16``).
        Forced output-scale factors are set on the six layers, and the three precision-change
        outputs are concatenated (in the order low, high, low*high) by an ``a16_w16`` concat layer.

        Args:
            low (BaseHailoLayer): The low input layer.
            high (BaseHailoLayer): The high input layer.
            layer_norm_flow (nx.DiGraph): The layer normalization flow graph.
            layer_norm (HailoLayerNormalization): The Hailo layer normalization object.

        Returns:
            Tuple[BaseHailoLayer, Tuple[int, int, int]]: The concat layer and the per-input
            factors ``(1, 1, 2)``, one per concatenated input. The caller
            ``_inverse_variance_flow_decomposed`` forwards them as ``fusing_factors``.
        """
        if self._optimize_ew_mult:
            low_kernel, high_kernel, cross_kernel = [2, 2], [8, 8], [2, 8]
        else:
            low_kernel, high_kernel, cross_kernel = [2, 2], [2, 2], [2, 2]

        def _widen(node, suffix):
            """Append a ``PrecisionMode.a8_w8_a16`` precision change after ``node``."""
            return self._precision_change_flow(
                node,
                layer_norm_flow,
                self._get_new_layer_name(suffix, layer_norm),
                PrecisionMode.a8_w8_a16,
                layer_norm,
            )

        def _add_square(part, suffix, kernel_values):
            """Create the square of ``part``, wire it in, and return it with ``part``'s output shape."""
            square_name = self._get_new_layer_name(suffix, layer_norm)
            part_shape = self._get_layer_output_shape(part)
            square = self.create_square_layer(
                square_name,
                part_shape,
                mock_kernel_values=kernel_values,
                precision_mode=PrecisionMode.a8_w8,
                bias_mode=BiasMode.double_scale_initialization,
                original_names=layer_norm.hn_element.get("original_names", []),
            )
            layer_norm_flow.add_node(square.full_name, layer=square)
            layer_norm_flow.add_edge(part.full_name, square.full_name)
            return square, part_shape

        square_low, low_shape = _add_square(low, "square_low", low_kernel)
        widened_low = _widen(square_low, "precision_change_low")

        square_high, high_shape = _add_square(high, "square_high", high_kernel)
        widened_high = _widen(square_high, "precision_change_high")

        cross = self.create_ew_mult_layer(
            self._get_new_layer_name("ew_mult_low_high", layer_norm),
            [low_shape, high_shape],
            mock_kernel_values=cross_kernel,
            precision_mode=PrecisionMode.a8_w8,
            bias_mode=BiasMode.double_scale_initialization,
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        layer_norm_flow.add_node(cross.full_name, layer=cross)
        layer_norm_flow.add_edge(low.full_name, cross.full_name, input_index=0)
        layer_norm_flow.add_edge(high.full_name, cross.full_name, input_index=1)
        widened_cross = _widen(cross, "precision_change_low_high")

        # Set the ratio between input scale and output scale for the relevant layers:
        # a power of two per product, and 1 for every precision change.
        HIGH_FACTOR = 0
        LOW_FACTOR = 6
        HIGH_LOW_FACTOR = 4
        scale_exponents = (
            (square_low, LOW_FACTOR),
            (square_high, HIGH_FACTOR),
            (cross, HIGH_LOW_FACTOR),
        )
        for layer, exponent in scale_exponents:
            layer.forced_output_scale_scalar_dof = 2**exponent
        for layer in (widened_low, widened_high, widened_cross):
            layer.forced_output_scale_scalar_dof = 1

        concat_name = self._get_new_layer_name("concat_layer", layer_norm)
        concat_inputs = (widened_low, widened_high, widened_cross)
        concat = self.create_concat_layer(
            concat_name,
            [self._get_layer_output_shape(node) for node in concat_inputs],
            precision_mode=PrecisionMode.a16_w16,
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        layer_norm_flow.add_node(concat.full_name, layer=concat)
        for index, node in enumerate(concat_inputs):
            layer_norm_flow.add_edge(node.full_name, concat.full_name, input_index=index)

        return concat, (1, 1, 2)

    def _precision_change_flow(
        self,
        input_node: BaseHailoLayer,
        layer_norm_flow: nx.DiGraph,
        full_name: str,
        precision_mode: PrecisionMode,
        layer_norm: HailoLayerNormalization,
    ):
        """
        Adds a precision change layer to the flow, directly after the given input node.

        Args:
            input_node (BaseHailoLayer): The node whose output feeds the precision change.
            layer_norm_flow (nx.DiGraph): The graph representing the layer normalization flow.
            full_name (str): The full name of the precision change layer.
            precision_mode (PrecisionMode): The precision mode for the precision change layer.
            layer_norm (HailoLayerNormalization): The layer norm being decomposed; its HN element
                supplies the ``original_names`` of the new layer.

        Returns:
            prec_change (BaseHailoLayer): The precision change layer created.

        """
        converter = self.create_precision_change_layer(
            full_name,
            self._get_layer_output_shape(input_node),
            precision_mode,
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        layer_norm_flow.add_node(converter.full_name, layer=converter)
        layer_norm_flow.add_edge(input_node.full_name, converter.full_name)
        return converter

    def _split_precision_flow(
        self,
        input_node: BaseHailoLayer,
        layer_norm_flow: nx.DiGraph,
        layer_norm: HailoLayerNormalization,
    ):
        """Add a precision-split layer and one shortcut layer per split output.

        The split layer's name is recorded in ``self._precision_split_layers``. When
        ``self._square_12_bit`` is set, an equalization source is added and becomes the node
        feeding the split; the same 12-bit ``APUOutputQuantElement`` is attached as the split's
        external input lossy element and as that source's external output lossy element.

        Args:
            input_node: Layer whose output shape defines the split; it also feeds the split
                unless the 12-bit path replaces it.
            layer_norm_flow: Graph that receives the new nodes and edges.
            layer_norm: The layer-norm layer being decomposed.

        Returns:
            Tuple ``(low, high)``: ``a8_w8`` shortcut layers attached to split outputs 0 and 1.
        """
        split_name = self._get_new_layer_name("precision_split", layer_norm)
        splitter = self.create_precision_split_layer(
            split_name,
            self._get_layer_output_shape(input_node),
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        self._precision_split_layers.add(splitter.full_name)

        source_node = input_node
        if self._square_12_bit:
            quant_12_bit = APUOutputQuantElement(bits=12)
            splitter.input_lossy_element_external = quant_12_bit
            eq_src_name = self.add_equalization_source(
                layer_norm_flow, layer_norm, PrecisionMode.a16_w16, BiasMode.single_scale_decomposition
            )
            source_node = layer_norm_flow.nodes[eq_src_name]["layer"]
            source_node.output_lossy_element_external = quant_12_bit

        layer_norm_flow.add_node(splitter.full_name, layer=splitter)
        layer_norm_flow.add_edge(source_node.full_name, splitter.full_name)

        low_name = self._get_new_layer_name("shortcut1_low", layer_norm)
        high_name = self._get_new_layer_name("shortcut2_high", layer_norm)
        split_shapes = self._get_layer_output_shape(splitter)
        low = self.create_shortcut_layer(
            low_name,
            split_shapes[0],
            PrecisionMode.a8_w8,
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        high = self.create_shortcut_layer(
            high_name,
            split_shapes[1],
            PrecisionMode.a8_w8,
            original_names=layer_norm.hn_element.get("original_names", []),
        )

        shortcuts = (low, high)
        for shortcut in shortcuts:
            layer_norm_flow.add_node(shortcut.full_name, layer=shortcut)
        for out_index, shortcut in enumerate(shortcuts):
            layer_norm_flow.add_edge(splitter.full_name, shortcut.full_name, output_index=out_index)

        return low, high

    def _square_flow(
        self,
        input_node: BaseHailoLayer,
        layer_norm_flow: nx.DiGraph,
        layer_norm: HailoLayerNormalization,
        precision_mode,
        bias_mode,
    ):
        """Add a square layer named ``square1`` after ``input_node``.

        Args:
            input_node: Layer whose output is squared.
            layer_norm_flow: Graph that receives the new node and edge.
            layer_norm: The layer-norm layer being decomposed (used for naming and original names).
            precision_mode: Precision mode given to the square layer.
            bias_mode: Bias mode given to the square layer.

        Returns:
            The created square layer.
        """
        source_shape = self._get_layer_output_shape(input_node)
        squared = self.create_square_layer(
            self._get_new_layer_name("square1", layer_norm),
            source_shape,
            precision_mode=precision_mode,
            bias_mode=bias_mode,
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        layer_norm_flow.add_node(squared.full_name, layer=squared)
        layer_norm_flow.add_edge(input_node.full_name, squared.full_name)
        return squared

    def _subtract_average_flow(
        self,
        input_node: BaseHailoLayer,
        layer_norm_flow: nx.DiGraph,
        layer_norm: HailoLayerNormalization,
        precision_mode="a16_w16",
        bias_mode="single_scale_decomposition",
    ):
        """
        Add the subtract average block to the flow.

        A reduce-mean over axis 3 (zero bias, linear activation, ``layer_norm.groups`` groups) is
        computed from ``input_node``. When the reduction is over axes (1, 2, 3) or
        ``layer_norm.groups > 1`` a spatial mean + resize follows it. The resulting mean is
        wired to input 1 of an element-wise subtraction whose input 0 is ``input_node``.

        Returns the element-wise subtraction layer.
        """
        # TODO: consider automatic name assignment
        mean_name = self._get_new_layer_name("reduce_mean1", layer_norm)
        sub_name = self._get_new_layer_name("ew_sub1", layer_norm)

        source_shape = self._get_layer_output_shape(input_node)

        channel_mean = self.create_reduce_mean_layer(
            mean_name,
            source_shape,
            reduce_axes=[3],
            bias=np.array(0, dtype=np.float32),
            activation="linear",
            groups=layer_norm.groups,
            precision_mode=precision_mode,
            bias_mode=bias_mode,
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        layer_norm_flow.add_node(channel_mean.full_name, layer=channel_mean)
        layer_norm_flow.add_edge(input_node.full_name, channel_mean.full_name)

        mean_node = channel_mean
        needs_spatial_mean = tuple(layer_norm.reduce_axes) == (1, 2, 3) or layer_norm.groups > 1
        if needs_spatial_mean:
            mean_node = self._spatial_mean_flow(
                channel_mean,
                layer_norm_flow,
                layer_norm,
                precision_mode,
                bias_mode,
                suffix="spatial",
            )

        mean_shape = self._get_layer_output_shape(mean_node)
        centered = self.create_ew_sub_layer(
            sub_name,
            [source_shape, mean_shape],
            precision_mode=precision_mode,
            bias_mode=bias_mode,
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        layer_norm_flow.add_node(centered.full_name, layer=centered)
        for index, operand in enumerate((input_node, mean_node)):
            layer_norm_flow.add_edge(operand.full_name, centered.full_name, input_index=index)
        return centered

    def _spatial_mean_flow(
        self,
        input_node: BaseHailoLayer,
        layer_norm_flow: nx.DiGraph,
        layer_norm: HailoLayerNormalization,
        precision_mode,
        bias_mode,
        suffix,
        bias=None,
        activation="linear",
    ):
        """Append a spatial mean followed by a resize after ``input_node``.

        For inputs with at most ``256**2`` elements (batch dimension excluded) a single
        reduce-mean over axes [1, 2] is added and it carries ``bias`` and ``activation``.
        For larger inputs the reduction is split in two: a linear, bias-free reduce-mean over
        axis [2], then a reduce-mean over axis [1] that carries ``bias`` and ``activation``.
        The reduced output is then resized with ``input_node.input_shape`` as the target.

        Args:
            input_node: Layer whose output is reduced.
            layer_norm_flow: Graph that receives the new nodes and edges.
            layer_norm: The layer-norm layer being decomposed (used for naming and original names).
            precision_mode: Precision mode given to the created layers.
            bias_mode: Bias mode given to the reduce-mean layers.
            suffix: Text appended to the generated layer names.
            bias: Bias of the last reduce-mean layer.
            activation: Activation of the last reduce-mean layer.

        Returns:
            The resize layer that closes the sub-flow.
        """
        reduce_mean1_name = self._get_new_layer_name(f"reduce_mean1_{suffix}", layer_norm)
        reduce_mean2_name = self._get_new_layer_name(f"reduce_mean2_{suffix}", layer_norm)
        resize1_name = self._get_new_layer_name(f"resize1_{suffix}", layer_norm)

        input_shape = self._get_layer_output_shape(input_node)

        # input_shape[0] is the -1 batch placeholder written by _hn_shape_prep. Including it made
        # the product negative, so the size check could never pass and the two-stage reduction
        # below was unreachable. Only the real dimensions are multiplied.
        add_two_spatial = np.prod(input_shape[1:]) > 256**2
        if add_two_spatial:
            first_reduce_axes = [2]
            first_activation = "linear"
            first_bias = None
        else:
            first_reduce_axes = [1, 2]
            first_activation = activation
            first_bias = bias
        reduce_mean1 = self.create_reduce_mean_layer(
            reduce_mean1_name,
            input_shape,
            reduce_axes=first_reduce_axes,
            bias=first_bias,
            activation=first_activation,
            bias_mode=bias_mode,
            precision_mode=precision_mode,
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        layer_norm_flow.add_node(reduce_mean1.full_name, layer=reduce_mean1)
        layer_norm_flow.add_edge(input_node.full_name, reduce_mean1.full_name)
        last_reduce = reduce_mean1
        if add_two_spatial:
            # Second stage: finish the reduction over axis 1 and apply the requested bias/activation.
            reduce_mean2 = self.create_reduce_mean_layer(
                reduce_mean2_name,
                self._get_layer_output_shape(reduce_mean1),
                reduce_axes=[1],
                bias=bias,
                activation=activation,
                bias_mode=bias_mode,
                precision_mode=precision_mode,
                original_names=layer_norm.hn_element.get("original_names", []),
            )
            layer_norm_flow.add_node(reduce_mean2.full_name, layer=reduce_mean2)
            layer_norm_flow.add_edge(reduce_mean1.full_name, reduce_mean2.full_name)
            last_reduce = reduce_mean2

        reduced_shape = self._get_layer_output_shape(last_reduce)
        resize1 = self.create_resize_layer(
            resize1_name,
            reduced_shape,
            input_node.input_shape,
            precision_mode=precision_mode,
            channels=False,
            original_names=layer_norm.hn_element.get("original_names", []),
        )
        layer_norm_flow.add_node(resize1.full_name, layer=resize1)
        layer_norm_flow.add_edge(last_reduce.full_name, resize1.full_name)

        return resize1

    # endregion

    # region helper functions
    def _get_layer_output_shape(self, layer):
        """Return the HN-style output shape(s) of ``layer``.

        The layer's ``compute_output_shape`` is fed ``input_shape`` for single-input layers and
        ``input_shapes`` otherwise. A single-output layer yields one prepared shape; any other
        layer yields a list with one prepared shape per output (see ``_hn_shape_prep``).
        """
        layer_input = layer.input_shape if layer.num_inputs == 1 else layer.input_shapes
        computed = layer.compute_output_shape(layer_input)
        if layer.num_outputs != 1:
            return [self._hn_shape_prep(single_output) for single_output in computed]
        return self._hn_shape_prep(computed)

    def _get_new_layer_name(self, layer_name, layer_norm):
        # Insignificant caching for layer names... (The cache is never used, because each name is used only once.)
        self._layer_norm_name_split_cache.setdefault(layer_norm.name, {})
        if layer_name not in self._layer_norm_name_split_cache[layer_norm.name]:
            block_name, layer_norm_name = self.get_block_and_layer_names(layer_norm.name)
            self._layer_norm_name_split_cache[layer_norm.name][layer_name] = (
                block_name,
                layer_norm_name,
            )

        block_name, layer_norm_name = self._layer_norm_name_split_cache[layer_norm.name][layer_name]
        model_name = layer_norm.full_name.split("/", 1)[0]
        return f"{model_name}/{block_name}{layer_name}_{layer_norm_name}"

    @staticmethod
    def _get_divisors(num, minval=1, maxval=12) -> NDArray[np.int32]:
        divisors = []
        for i in range(minval, maxval + 1):
            if num % i == 0:
                divisors.append(i)
        return np.array(divisors, dtype=np.int32)

    @staticmethod
    def _hn_shape_prep(shape):
        # Explicit python int conversion for jsonschema verification.
        return [-1, *[int(s) for s in shape[1:]]]

    @staticmethod
    def _create_elementwise_repeats(input_shapes):
        in1 = np.array(input_shapes[0][1:])
        in2 = np.array(input_shapes[1][1:])
        repeats1 = np.maximum(in1 // in2, 1)
        repeats2 = np.maximum(in2 // in1, 1)
        if np.any(in1 * repeats2 != in2 * repeats1):
            raise ValueError("Input shapes are not compatible for elementwise operation")
        # Explicit python int conversion for jsonschema verification.
        repeats1 = [int(i) for i in repeats1]
        repeats2 = [int(i) for i in repeats2]

        input_repeats = [list(repeats2), list(repeats1)]
        return input_repeats

    # endregion

    # region decompose default functions
    def _build_mercury(self, layer_norm):
        """Replace ``layer_norm`` in the model with a ``HailoLayerNormMercury`` (mercury scheme).

        The new layer is built from the original layer's HN element, receives its exported
        weights, takes its place in ``self._model.layers`` under the same full name, and is
        configured with ``single_scale_decomposition`` bias and ``a8_w8`` precision.
        """
        hn_element = layer_norm.to_hn()
        exported_weights = layer_norm.export_weights()
        mercury_layer = HailoLayerNormMercury.from_hn(lname=layer_norm.full_name, hn_element=hn_element)
        mercury_layer.import_weights(exported_weights)

        # The layers mapping is edited between an explicit unlock / lock of the model.
        model = self._model
        model._unlock_model()
        model.layers[layer_norm.full_name] = mercury_layer
        model._lock_model()

        self.add_config(mercury_layer, bias_mode="single_scale_decomposition", precision_mode="a8_w8")

    def _build_pluto(self, layer_norm):
        """Replace ``layer_norm`` in the model with a ``HailoLayerNorm`` (pluto scheme).

        The new layer is built from the original layer's HN element, receives its exported
        weights, takes its place in ``self._model.layers`` under the same full name, and is
        configured with ``single_scale_decomposition`` bias and ``a16_w16`` precision.
        """
        hn_element = layer_norm.to_hn()
        exported_weights = layer_norm.export_weights()
        pluto_layer = HailoLayerNorm.from_hn(lname=layer_norm.full_name, hn_element=hn_element)
        pluto_layer.import_weights(exported_weights)

        # The layers mapping is edited between an explicit unlock / lock of the model.
        model = self._model
        model._unlock_model()
        model.layers[layer_norm.full_name] = pluto_layer
        model._lock_model()

        self.add_config(pluto_layer, bias_mode="single_scale_decomposition", precision_mode="a16_w16")

    # endregion

    def _get_nudged_factors(self, factor_candidate):
        factors = []
        version = "version1"
        if version == "version1":
            max_number_of_2 = 4
            max_number_of_3 = 2
            max_number_of_5 = 0
        elif version == "version2":
            max_number_of_2 = 2
            max_number_of_3 = 2
            max_number_of_5 = 1
        else:
            max_number_of_2 = 7
            max_number_of_3 = 0
            max_number_of_5 = 0
        for num2 in range(max_number_of_2 + 1):
            for num3 in range(max_number_of_3 + 1):
                for num5 in range(max_number_of_5 + 1):
                    factors.append(2**num2 * 3**num3 * 5**num5)

        def find_closest_smaller_element(a, b):
            result = []
            for num_a in a:
                closest_smaller = max(filter(lambda x: x <= num_a, b), default=None)
                result.append(closest_smaller)
            return np.array(result)

        result = find_closest_smaller_element(factor_candidate, factors)
        max_factor = 2**max_number_of_2 * 3**max_number_of_3 * 5**max_number_of_5

        def get_lcm(nums):
            # calculates the LCM (Least Common Multiple) of a list of numbers.
            # Use the 'reduce' function to apply the 'lcm' function cumulatively to the elements of 'nums'.
            return reduce(lambda x, y: lcm(x, y), nums)

        def gcd(a, b):
            # 'gcd' - Greatest Common Divisor (GCD) of two numbers.
            # Use the Euclidean algorithm to find the GCD of 'a' and 'b'.
            while b:
                a, b = b, a % b
            return a

        def lcm(a, b):
            # 'lcm' - Least Common Multiple (LCM) of two numbers.
            # Calculate the LCM using the formula: LCM(a, b) = (a * b) / GCD(a, b).
            return a * b // gcd(a, b)

        _ = get_lcm(np.unique(result))

        return result, max_factor

    def _calibrate_summed_squares(
        self,
        square_layer: HailoFeatureMultiplierOnMac,
        mean_layer: HailoConv,
    ):
        """
        Rebuild the square layer with grouped reduce-sum, and the mean layer that consumes it.

        The group size (how many squared channels are summed together) is either taken from the
        algorithm config (``square_reduce_sum_groups``) or chosen from the output statistics of
        ``square_layer``: the last candidate divisor, in iteration order, for which every group's
        summed range stays below ``max_range * sum_range_factor``. Both layers are then recreated
        and swapped into the model with ``replace_layer``.

        Args:
            square_layer: existing square layer; its input shape, activation, precision mode,
                bias mode and original names are carried over to the new layer.
            mean_layer: existing mean conv; its last output dimension is used as the number of
                layer-norm groups, and its bias, activation and modes are carried over.

        Raises:
            ValueError: if the configured group size does not divide the channel count, or the
                resulting number of reduce-sum groups is not divisible by the layer-norm groups.

        """
        cfg = self.get_algo_config()
        channels = square_layer.input_shape[-1]
        layer_norm_groups = mean_layer.output_shape[-1]
        # Candidate group sizes exist only on the statistics-driven path. Bind the name up front
        # so the final sanity-check message can be formatted on both paths.
        divisors = None
        if cfg.square_reduce_sum_groups is not None:
            # Despite its name, the configured value is used here as the size of each group.
            desired_group_size = cfg.square_reduce_sum_groups
            # NOTE(review): `cfg.get_feature` is formatted without being called, unlike
            # `cfg.get_command()` next to it - confirm whether it is a property or a method.
            if channels % desired_group_size != 0:
                raise ValueError(
                    f"Number of channels must be divisible by the number of elements in each group. "
                    f"group size: {desired_group_size}, channels: {channels}, layer: {square_layer.full_name}. "
                    f"Please change the group size or don't use square_reduce_sum_groups. "
                    f"{cfg.get_command()}({cfg.get_feature}, square_reduce_sum_groups=<valid divisor>) or "
                    f"remove the square_reduce_sum_groups parameter."
                )

            reduce_sum_groups = int(channels // desired_group_size)
            if reduce_sum_groups % layer_norm_groups != 0:
                raise ValueError(
                    f"Number of reduce_sum_groups must be divisible by the number of layer norm groups. "
                    f"reduce_sum_groups = channels//desired_group_size: {reduce_sum_groups}, layer_norm_groups: {layer_norm_groups}, layer: {square_layer.full_name}. "
                    f"Please change the group size or don't use square_reduce_sum_groups. "
                    f"{cfg.get_command()}({cfg.get_feature}, square_reduce_sum_groups=<valid divisor>) or "
                    f"remove the square_reduce_sum_groups parameter."
                )
        else:
            factor = cfg.sum_range_factor
            # Naive value of this factor should be 2:
            #   15 bit * 15 bits = 30 bits
            #   30 bits + 30 bits = 31 bits (which is the signed wraparound limit)
            # Realistically, the mean value of the channels is lower, but it involves some wraparound risk.
            # Therefore, this value is configurable. # TODO: should it be configurable per layer?
            square_stats = square_layer.get_output_stats()[0]
            square_ranges = np.maximum(square_stats.max, 0) - np.minimum(square_stats.min, 0)
            max_range = np.max(square_ranges)

            # Small optimization to avoid iterating over all divisors: only divisors up to half of the
            # largest one are requested, and the largest divisor itself is appended afterwards.
            largest_divisor = int(channels // layer_norm_groups)
            divisors = self._get_divisors(channels, 1, largest_divisor // 2)
            divisors = np.append(divisors, largest_divisor).astype(np.int32)

            # Keep only group sizes whose group count is a multiple of the layer-norm groups.
            # Note: if layer_norm_groups is 1, this filter is redundant.
            divisors = divisors[(channels // divisors) % layer_norm_groups == 0]

            desired_group_size = 1

            # Stops at the first candidate that exceeds the allowed range.
            # Presumably `_get_divisors` returns ascending values - TODO confirm.
            for div in divisors:
                tmp_reshape = np.reshape(square_ranges, [-1, div])
                tmp_sum = np.sum(tmp_reshape, axis=-1)
                if np.all(tmp_sum < max_range * factor):
                    desired_group_size = div
                else:
                    break

        reduce_sum_groups = int(channels // desired_group_size)

        if channels % reduce_sum_groups != 0 or reduce_sum_groups % layer_norm_groups != 0:
            # This should not happen; it is only a sanity check.
            raise ValueError(
                f"Number of channels must be divisible by the number of reduce_sum_groups and number of reduce_sum_groups be divisible by the layer_norm_groups "
                f"reduce_sum_groups: {reduce_sum_groups}, channels: {channels}, layer_norm_groups {layer_norm_groups} layer: {square_layer.full_name}. "
                f"something is wrong!!! candidate group sizes: {divisors}."
            )

        new_square_layer = self.create_square_layer(
            square_layer.full_name,
            square_layer.input_shape,
            activation=square_layer.activation_atomic_op.act_name.value,
            reduce_sum_groups=reduce_sum_groups,
            precision_mode=square_layer.get_precision_mode().reduce().value,
            bias_mode=square_layer.get_bias_mode().value,
            original_names=square_layer.hn_element.get("original_names", []),
        )

        # The new mean layer consumes the output of the rebuilt square layer.
        input_shape = self._get_layer_output_shape(new_square_layer)

        new_mean_layer = self.create_conv_mean_layer(
            mean_layer.full_name,
            input_shape,
            activation=mean_layer.activation_atomic_op.act_name.value,
            factors=(1 / desired_group_size,),
            norm_groups=int(mean_layer.output_shape[-1]),
            bias=mean_layer.bias,
            precision_mode=mean_layer.get_precision_mode().reduce().value,
            bias_mode=mean_layer.get_bias_mode().value,
            original_names=mean_layer.hn_element.get("original_names", []),
        )
        self._model.replace_layer(new_square_layer, square_layer)
        self._model.replace_layer(new_mean_layer, mean_layer)

    # region optimization Functions
    def equalize_layer(self, layer_norm_equiv_class: EquivClassNorm):
        """
        Equalize the source layer of a layer-norm equivalence class against its two consumers.

        Per-entry factors are derived from the source layer's input statistics and nudged by
        ``_get_nudged_factors``. The source kernel is multiplied by these factors, the
        ``consumer_out`` kernel by their inverse, and the ``consumer_square`` kernel by the
        squared inverse. Forced kernel scales and forced output factors are written to the
        three layers through ``_import_new_kernel_q``.

        The method returns without touching any layer when the computed accumulator exponent
        for the square consumer exceeds 23.

        Args:
            layer_norm_equiv_class: holds the names of the ``source``, ``consumer_square`` and
                ``consumer_out`` layers, which are looked up in ``self._model.layers``.

        """
        source_layer = self._model.layers[layer_norm_equiv_class.source]
        consumer_square_layer = self._model.layers[layer_norm_equiv_class.consumer_square]
        consumer_out_layer = self._model.layers[layer_norm_equiv_class.consumer_out]

        equalization_factors = self.get_equalization_factors(source_layer)

        # c1_factor is the second value returned by _get_nudged_factors (its largest allowed factor).
        equalization_source_kernel, c1_factor = self._get_nudged_factors(equalization_factors)
        equalization_consumer_out_kernel = 1 / equalization_source_kernel
        equalization_consumer_square_kernel = equalization_consumer_out_kernel**2

        ####### set default scale values
        scale_source_candidate = 1
        scale_consumer_out_candidate = 1 / c1_factor
        scale_consumer_square = scale_consumer_out_candidate**2

        # Ratio between the square consumer's input channels and the source's input channels.
        split_ratio = consumer_square_layer.input_shape[-1] // source_layer.input_shape[-1]

        max_desired_exponent = self._calc_max_accumulator_shift(
            equalization_consumer_square_kernel,
            scale_consumer_square,
            consumer_square_layer.conv_op.groups // split_ratio,
        )
        if max_desired_exponent > 23:
            # TODO: technically we can remove the equalization layers in this case.
            return
        # NOTE(review): max_desired_exponent is not rounded here; its floor is only computed in
        # the DEBUG section below - confirm a non-integer exponent is intended.
        val = -(max_desired_exponent - 7)
        forced_output_factor_sqaure = 2**val

        # Candidate quantized kernels, i.e. the equalization kernels divided by their candidate scales.
        equalization_source_kernel_q_candidate = equalization_source_kernel / scale_source_candidate
        equalization_consumer_out_kernel_q_candidte = equalization_consumer_out_kernel / scale_consumer_out_candidate

        # The source is equalized on its output side, the out-consumer on its input side.
        scale_source, source_forced_output_factor = self._calc_needed_factor(
            source_layer,
            scale_source_candidate,
            equalization_source_kernel_q_candidate,
            factor_in=None,
            factor_out=equalization_source_kernel,
        )
        scale_consumer_out, consumer_forced_output_factor = self._calc_needed_factor(
            consumer_out_layer,
            scale_consumer_out_candidate,
            equalization_consumer_out_kernel_q_candidte,
            factor_in=equalization_source_kernel,
            factor_out=None,
        )

        ####################################################################

        # Write the equalized kernels, forced scales and forced output factors to the three layers.
        self._import_new_kernel_q(
            source_layer,
            equalization_source_kernel,
            scale_source,
            source_forced_output_factor,
            split_ratio,
        )
        self._import_new_kernel_q(
            consumer_out_layer,
            equalization_consumer_out_kernel,
            scale_consumer_out,
            consumer_forced_output_factor,
            split_ratio,
        )
        self._import_new_kernel_q(
            consumer_square_layer,
            equalization_consumer_square_kernel,
            scale_consumer_square,
            forced_output_factor=forced_output_factor_sqaure,
            split_ratio=split_ratio,
        )

        # The first sorted successor of the source gets its splits created with the source factors
        # (only if it is a HailoPrecisionSplit) and is dropped from the pending precision-split list.
        lname = self._model.flow.successors_sorted(source_layer.full_name)[0]
        self._split_layer(self._model, lname, equalization_source_kernel)
        if lname in self._precision_split_layers:
            self._precision_split_layers.remove(lname)

        # Diagnostics only: prints a summary of the chosen factors when the DEBUG flag is set.
        if DEBUG:
            equalization_source_kernel_q = equalization_source_kernel / scale_source
            equalization_consumer_out_kernel_q = equalization_consumer_out_kernel / scale_consumer_out
            equalization_consumer_square_kernel_q = equalization_consumer_square_kernel / scale_consumer_square
            max_accumulator_shift_floor = np.floor(max_desired_exponent)

            unique_sol = np.unique(equalization_source_kernel_q)
            diff = equalization_factors - equalization_source_kernel_q
            # NOTE(review): repeats the unique_sol assignment two lines above.
            unique_sol = np.unique(equalization_source_kernel_q)
            unique_sol_out = np.unique(equalization_consumer_out_kernel_q)
            unique_sol_square = np.unique(equalization_consumer_square_kernel_q)

            print("diff", min(diff), max(diff))

            print("solution", unique_sol, len(unique_sol))
            print("unique_sol_out", unique_sol_out, len(unique_sol_out))
            print("unique_sol_square", unique_sol_square, len(unique_sol_square))

            expected_forced_output_factor = c1_factor**2
            print(
                f"max_accumulator_shift_floor {max_accumulator_shift_floor} "
                f"forced_output_factor_sqaure {forced_output_factor_sqaure} "
                f"expected_forced_output_factor {expected_forced_output_factor} "
                f"ratio {expected_forced_output_factor / forced_output_factor_sqaure}"
            )

    @staticmethod
    def get_equalization_factors(source_layer):
        """
        get equalization factors for the layer
        """
        input_stats = source_layer.get_input_stats()[0]
        epsilon = 1e-10

        stats_abs = np.maximum(np.maximum(np.abs(input_stats.min), np.abs(input_stats.max)), epsilon)
        max_all = np.max(stats_abs)

        factor_all = max_all / stats_abs

        return factor_all

    @staticmethod
    def _calc_max_accumulator_shift(equalization_consumer_square_kernel, scale_consumer_square, groups):
        equalization_consumer_square_kernel_q = equalization_consumer_square_kernel / scale_consumer_square
        val = equalization_consumer_square_kernel_q.reshape((-1, groups))
        sumation = np.sum(val, axis=0)
        log_2 = np.log2(sumation)
        max_accumulator_shift = np.max(log_2)
        return max_accumulator_shift

    def _calc_needed_factor(self, layer, kernel_scale_forced, kernel_q, factor_in=None, factor_out=None):
        """
        Compute the forced kernel scale and forced output factor of ``layer`` after equalization.

        Steps:
        1. Calculate the expected output factor (accumulator scale / output scale) and nudge it
           by flooring its mantissa.
        2. If a shift is needed, push it into the forced kernel scale, limited so that the
           largest ``kernel_q`` value times ``2**shift`` stays within ``2**15 - 1``.

        Args:
            layer: layer whose conv, output and activation ops provide the scale candidates.
            kernel_scale_forced: candidate kernel scale before the shift correction.
            kernel_q: candidate quantized kernel; only its maximum is used.
            factor_in: forwarded as ``factor`` to the input encoding candidates.
            factor_out: forwarded as ``factor`` to the output encoding candidates.

        Returns:
            Tuple ``(kernel_scale_forced, forced_output_factor)`` after the shift correction.

        """
        # get input_scales and output_scales after "equalization"
        input_scale, _ = layer.conv_op.calc_input_encoding_candidates(0, factor=factor_in)
        output_scale, _ = layer.output_op.calc_output_encoding_candidates(
            0,
            output_lossy_external=layer.output_lossy_element_external,
            factor=factor_out,
        )

        ###### nudge the output scale if needed
        # input_scale is sampled with a stride of kernel.shape[-2] before being scaled by the kernel scale.
        acc_scale = input_scale[:: layer.kernel.shape[-2]] * kernel_scale_forced
        output_factors_val = np.array(acc_scale / output_scale, dtype=layer.act_op.FLOAT_TYPE_NP)
        nonzero = 1.0
        exponent_factors, mantissas_candidate, exponents = layer.act_op._get_mantissa_exponent_decomposition(
            np.array([nonzero], dtype=layer.act_op.FLOAT_TYPE_NP),
            output_factors_val,
        )
        # Flooring the mantissas is what nudges the output factor.
        mantissas = np.floor(mantissas_candidate)
        output_factor = np.squeeze(exponent_factors * mantissas * layer.act_op.final_shift_factor / nonzero)

        # push the shift into the kernel scale forced if needed
        act_op = layer.activation_atomic_op
        assigned_exp = act_op.get_assigned_exponent(-exponents)
        shift_fix_exp = np.max(-assigned_exp)
        # Only a non-negative shift is applied.
        shift_fix = np.max([0, np.max(shift_fix_exp)])

        # Largest shift for which max(kernel_q) * 2**shift does not exceed 2**15 - 1.
        max_shift = np.floor(np.log2((2**15 - 1) / np.max(kernel_q)))

        if shift_fix > max_shift:
            self.logger.info(f"the shift wnated to be {shift_fix} but is {max_shift}")
            shift_fix = max_shift

        # Only the first entry of the squeezed output factor is used.
        forced_output_factor = output_factor[0] / (2.0**shift_fix)
        # NOTE(review): `/=` works in place - a numpy array argument would be modified for the
        # caller as well; with a python scalar it only rebinds the local name.
        kernel_scale_forced /= 2**shift_fix

        return kernel_scale_forced, forced_output_factor

    @staticmethod
    def _split_layer(model, lname, factor=None):
        """Call ``create_splits(factor=factor)`` on layer ``lname`` when it is a HailoPrecisionSplit; otherwise do nothing."""
        candidate = model.layers[lname]
        if not isinstance(candidate, HailoPrecisionSplit):
            return
        candidate.create_splits(factor=factor)

    def mask_online_token_equalization(self, online_token_equalization_block: OnlineTokenEqualizationBlock):
        """
        Force the ranges of an online token equalization block and set its activation mask.

        Ranges are forced on the predecessor of the ``exp_decompose`` layer, on ``exp_decompose``
        itself and on the ``shift`` layer. The same ten-entry mask is then stored in the
        ``act_native_params`` of both ``exp_decompose`` and ``shift``.
        """
        mask_steps = np.array([1 / 64, 1 / 16, 1 / 4, 1 / 2, 1.0])

        exp_name = online_token_equalization_block.exp_decompose
        shift_name = online_token_equalization_block.shift

        exp_layer = self._model.layers[exp_name]
        # The input is signed, so one bit is subtracted from the input bit count.
        magnitude_bits = exp_layer.get_input_lossy_elements()[0].bits - 1
        levels = 2**magnitude_bits

        input_stats = exp_layer.act_op.get_input_stats(0)
        stretched_max = (levels / (levels - 1)) * np.abs(input_stats.max)
        upper = np.max(np.maximum(stretched_max, np.abs(input_stats.min)))
        lower = -upper

        # Force ranges so that input/output are symmetric, and exp_decompose is a power of 2.
        producer = self._model.flow.predecessors_sorted(exp_name)[0]
        self._force_range(producer, (((levels - 0.5) / levels) * lower, ((levels - 1) / levels) * upper))
        self._force_range(exp_name, (0, ((2 ** (magnitude_bits + 1) - 1) / (2 ** (magnitude_bits - 1))) * upper))
        self._force_range(shift_name, (-1, ((levels - 1) / levels)))

        # Negative half uses the steps in reversed order, positive half in ascending order.
        mask = np.concatenate([mask_steps[::-1] * lower, mask_steps * upper])

        self._model.layers[exp_name].act_op.act_native_params["mask"] = mask
        self._model.layers[shift_name].act_op.act_native_params["mask"] = mask

    def _remove_equalization_on_input_layer_norm(self, layer_norm):
        """Disable equalization for the first sorted predecessor of ``layer_norm``."""
        # TODO: iterate backwards to find the source layers before the layer norm
        predecessor = self._model.flow.predecessors_sorted(layer_norm.full_name)[0]
        disabled_cfg = LayerEqualizationConfig(policy="disabled")
        self._model_config.equalization.layers[predecessor] = disabled_cfg

    @classmethod
    def _import_new_kernel_q(
        cls,
        layer,
        equalization_kernel,
        kernel_scale_forced=1,
        forced_output_factor=None,
        split_ratio=3,
    ):
        """
        Scale the layer kernel by the equalization kernel and store it as a forced quantized kernel.

        Args:
            layer: layer whose kernel is rescaled and whose conv op receives the forced values.
            equalization_kernel: equalization factors applied to the kernel.
            kernel_scale_forced: scale that divides the new kernel to obtain the forced ``kernel_q``.
            forced_output_factor: when not None, stored on the layer as ``forced_output_factor``.
            split_ratio: number of tiles of the reshaped factors, used when the layer has groups.

        """
        original_kernel = layer.get_kernel().numpy()
        if layer.groups > 1:
            # Grouped kernel: arrange the factors per group, then tile them split_ratio times.
            per_group = equalization_kernel.reshape((-1, original_kernel.shape[2])).transpose()
            scaling = np.tile(per_group, split_ratio)
        else:
            scaling = np.expand_dims(equalization_kernel, 1)

        scaled_kernel = original_kernel * scaling
        conv = layer.conv_op
        conv.import_weights(scaled_kernel)

        conv.kernel_q_forced = scaled_kernel / kernel_scale_forced
        conv.kernel_scale_forced = kernel_scale_forced
        conv.kernel_scale_forced_to_save = True

        if forced_output_factor is None:
            return
        layer.forced_output_factor = forced_output_factor
    # endregion

    # region layer creation functions
    def init_common_layer(self, hn_element, full_name, weights=None) -> BaseHailoLayer:
        """Generate the acceleras layer for ``hn_element`` and import ``weights`` (an empty dict when None)."""
        generated = gen_acceleras_layers_from_hn(full_name, hn_element, self.optimization_target)
        created = generated[full_name]
        created.import_weights({} if weights is None else weights)
        return created

    def _fill_config(self, layer: BaseHailoLayer, precision_mode=None, bias_mode=None):
        """
        Register per-layer configs for ``layer`` and import its precision config.

        Equalization, adaround, bias correction and zero-static-channels are set to "disabled",
        and the negative exponent rank to 0. ``precision_mode`` / ``bias_mode`` fall back to
        the layer's default precision config when they are None.
        """
        name = layer.full_name
        model_cfg = self._model_config

        model_cfg.equalization.layers[name] = LayerEqualizationConfig(policy="disabled")
        model_cfg.adaround.layers[name] = LayerAdaRoundConfig(policy="disabled")
        model_cfg.negative_exponent.layers[name] = LayerNegExponentConfig(rank=0)
        model_cfg.bias_correction.layers[name] = LayerBiasCorrectionConfig(policy="disabled")
        model_cfg.zero_static_channels.layers[name] = LayerZeroStaticChannelsConfig(policy="disabled")

        defaults = layer.get_default_precision_config()
        chosen_precision = defaults.precision_mode if precision_mode is None else precision_mode
        chosen_bias = defaults.bias_mode if bias_mode is None else bias_mode

        precision_cfg = LayerPrecisionConfig(
            precision_mode=chosen_precision,
            bias_mode=chosen_bias,
            quantization_groups=1,
        )
        model_cfg.precision_config.layers[name] = precision_cfg
        layer.import_precision_config(precision_cfg, self.optimization_target)

    def create_normalization_layer(
        self,
        full_name,
        input_shape,
        activation="linear",
        factor=1,
        precision_mode="a16_w16",
        bias_mode="single_scale_decomposition",
        original_names=None,
    ):
        """
        Create a non-trainable "normalization" layer whose kernel is ``factor`` on every channel.

        The bias is all zeros and is marked as not correctable. The layer config is filled with
        the given precision and bias modes.
        """
        channels = input_shape[-1]
        kernel_shape = [1, 1, channels, 1]

        layer_params = {
            "kernel_shape": kernel_shape,
            "strides": [1, 1, 1, 1],
            "dilations": [1, 1, 1, 1],
            "padding": "VALID",
            "groups": 1,
            "layer_disparity": 1,
            "input_disparity": 1,
            "batch_norm": False,
            "elementwise_add": False,
            "activation": activation,
        }
        hn_element = {
            "type": "normalization",
            "input_shapes": [self._hn_shape_prep(input_shape)],
            "original_names": [] if original_names is None else original_names,
            "params": layer_params,
        }

        weights = {
            "kernel": (np.ones(channels, dtype=np.float32) * factor).reshape(kernel_shape),
            "bias": np.zeros([channels], dtype=np.float32),
        }

        normalization = self.init_common_layer(hn_element, full_name, weights)
        normalization.trainable = False
        normalization.bias_add_op.is_correctable = False
        self._fill_config(normalization, precision_mode=precision_mode, bias_mode=bias_mode)

        return normalization

    def create_conv_mean_layer(
        self,
        full_name,
        input_shape,
        factors=None,
        kernel=None,
        norm_groups=1,
        bias=None,
        activation="linear",
        precision_mode="a8_w8",
        bias_mode="double_scale_initialization",
        original_names=None,
    ):
        """
        Create a non-trainable 1x1 conv that averages its input channels per normalization group.

        The input channels are split evenly between the entries of ``factors``. Every channel
        that belongs to a factor gets the weight ``factor / channels_per_factor``, where
        ``channels_per_factor = input_channels // len(factors) // norm_groups``.

        Args:
            full_name: name of the new layer.
            input_shape: input shape; its last entry is the number of input channels.
            factors: sequence of scale factors. ``None`` means a plain mean, i.e. ``(1,)``.
            kernel: ignored; the kernel is always built from ``factors``. Kept in the signature
                for backward compatibility.
            norm_groups: number of output channels (normalization groups).
            bias: scalar bias applied to every output channel; ``None`` means 0.
            activation: activation name written to the hn element.
            precision_mode: precision mode used for the layer config.
            bias_mode: bias mode used for the layer config.
            original_names: original names written to the hn element; ``None`` means ``[]``.

        Returns:
            The created layer, with its bias marked as not correctable.

        """
        if factors is None:
            # The declared default used to fail on len(None); treat it as an unscaled mean.
            factors = (1,)

        input_channel = input_shape[-1]
        kernel_shape = [1, 1, input_channel, norm_groups]
        original_names = original_names if original_names is not None else []

        hn_element = {
            "type": "conv",
            "input_shapes": [self._hn_shape_prep(input_shape)],
            "original_names": original_names,
            "compilation_params": {},
            "quantization_params": {},
            "params": {
                "kernel_shape": kernel_shape,
                "strides": [1, 1, 1, 1],
                "dilations": [1, 1, 1, 1],
                "padding": "VALID",
                "groups": 1,
                "layer_disparity": 1,
                "input_disparity": 1,
                "batch_norm": False,
                "elementwise_add": False,
                "activation": activation,
            },
        }

        # Number of input channels that share one factor inside one normalization group.
        channels_per_factor = input_channel // len(factors) // norm_groups

        kernel_parts = []
        for fact in factors:
            weight = fact / channels_per_factor
            # With several groups each group only reads its own channels, hence the scaled identity.
            part = np.eye(norm_groups) * weight if norm_groups > 1 else weight
            kernel_parts.append(np.repeat(part, channels_per_factor, axis=0))
        mean_kernel = np.concatenate(kernel_parts).reshape(kernel_shape)

        if bias is None:
            bias = 0
        bias = np.ones([norm_groups]) * bias
        weights = {"kernel": mean_kernel, "bias": bias}

        new_layer = self.init_common_layer(hn_element, full_name, weights=weights)
        new_layer.trainable = False
        new_layer.bias_add_op.is_correctable = False

        self._fill_config(new_layer, precision_mode=precision_mode, bias_mode=bias_mode)
        return new_layer

    def create_conv_sum_layer(
        self,
        full_name,
        input_shape,
        channels_to_sum,
        bias=None,
        activation="linear",
        precision_mode="a8_w8",
        bias_mode="double_scale_initialization",
        original_names=None,
    ):
        """
        Create a non-trainable grouped 1x1 conv with an all-ones kernel.

        The layer has ``input_channels // channels_to_sum`` groups and the same number of
        output channels (e.g. 6144 / 128 = 48). The bias is ``bias`` on every output channel
        (0 when None) and is marked as not correctable.
        """
        input_channel = input_shape[-1]
        channels_out = input_channel // channels_to_sum

        layer_params = {
            "kernel_shape": [1, 1, input_channel, channels_out],
            "strides": [1, 1, 1, 1],
            "dilations": [1, 1, 1, 1],
            "padding": "VALID",
            "groups": channels_out,
            "layer_disparity": 1,
            "input_disparity": 1,
            "batch_norm": False,
            "elementwise_add": False,
            "activation": activation,
        }
        hn_element = {
            "type": "conv",
            "input_shapes": [self._hn_shape_prep(input_shape)],
            "original_names": [] if original_names is None else original_names,
            "params": layer_params,
        }

        # The actual kernel is per group, so its input dimension is channels_to_sum.
        kernel_shape = [1, 1, channels_to_sum, channels_out]
        kernel = np.ones(shape=tuple(kernel_shape), dtype=np.float32)

        if DEBUG:
            print("kernel shape", kernel.shape, kernel_shape)
            print("input_shapes", input_shape)

            print("number of groups", channels_out)
            print("number_of_channels_to_sum", channels_to_sum)

        bias_value = 0 if bias is None else bias
        weights = {"kernel": kernel, "bias": np.ones(shape=channels_out) * bias_value}

        sum_layer = self.init_common_layer(hn_element, full_name, weights=weights)
        sum_layer.trainable = False
        sum_layer.bias_add_op.is_correctable = False

        self._fill_config(sum_layer, precision_mode=precision_mode, bias_mode=bias_mode)
        return sum_layer

    def create_concat_layer(
        self,
        full_name,
        input_shapes,
        precision_mode,
        original_names=None,
    ):
        """Create a concat layer over the "features" axis, with ``vector_zp`` enabled on its atomic op."""
        prepared_shapes = [self._hn_shape_prep(shape) for shape in input_shapes]
        hn_element = {
            "type": "concat",
            "input_shapes": prepared_shapes,
            "original_names": [] if original_names is None else original_names,
            "params": {"concat_axis": "features"},
        }

        concat_layer = self.init_common_layer(hn_element, full_name)
        concat_layer.atomic_op.vector_zp = True
        self._fill_config(concat_layer, precision_mode=precision_mode)
        return concat_layer

    def create_reduce_mean_layer(
        self,
        full_name,
        input_shape,
        reduce_axes=None,
        bias=None,
        activation="linear",
        groups=1,
        precision_mode="a8_w8",
        bias_mode="double_scale_initialization",
        original_names=None,
    ):
        """
        Create a reduce_mean layer from an hn element.

        ``reduce_axes`` defaults to ``[3]``; only (1, 2), (1,), (2,) and (3,) are accepted.
        A bias weight is imported only when ``bias`` is given.

        Raises:
            ValueError: if the reduce axes combination is not one of the accepted ones.

        """
        # TODO: Technically, creation should be supported from init function, and not only from hn
        axes = [3] if reduce_axes is None else reduce_axes
        supported_axes = {(1, 2), (1,), (2,), (3,)}
        if tuple(axes) not in supported_axes:
            raise ValueError("Reduce axes configurion is not supported")  # Why doesn't the hailo layer check it?

        hn_element = {
            "type": "reduce_mean",
            "input_shapes": [self._hn_shape_prep(input_shape)],
            "original_names": [] if original_names is None else original_names,
            "params": {
                "reduce_axes": axes,
                "groups": groups,
                "activation": activation,
            },
        }

        weights = None if bias is None else {"bias": np.ones([1]) * bias}
        reduce_layer = self.init_common_layer(hn_element, full_name, weights=weights)
        self._fill_config(reduce_layer, precision_mode=precision_mode, bias_mode=bias_mode)
        return reduce_layer

    def create_precision_change_layer(self, full_name, input_shape, precision_mode, original_names=None):
        """
        Create a linear standalone activation that changes precision.

        Raises:
            ValueError: if the precision mode is neither ``a8_w8_a16`` nor ``a16_w16_a8``.

        """
        mode = PrecisionMode(precision_mode)
        supported_modes = (PrecisionMode.a8_w8_a16, PrecisionMode.a16_w16_a8)
        if mode not in supported_modes:
            raise ValueError(f"Precision mode {mode} is not supported for precision change layer")

        return self.create_standalone_activation(
            full_name,
            input_shape,
            activation="linear",
            precision_mode=mode,
            bias_mode=BiasMode.single_scale_decomposition,
            original_names=original_names,
        )

    def create_resize_layer(
        self,
        full_name,
        input_shape,
        output_shape,
        precision_mode="a8_w8",
        channels=True,
        original_names=None,
    ):
        """
        Create a nearest-neighbor resize layer.

        With ``channels=True`` only the last (feature) axis is resized. Otherwise only axes 1
        and 2 are resized and the feature axis is kept from the input shape.
        """
        if channels:
            ratio_h, ratio_w = 1.0, 1.0
            ratio_f = float(output_shape[3] / input_shape[3])
            resized_shape = [*input_shape[:3], output_shape[3]]
        else:  # Spatial resize
            ratio_h = float(output_shape[1] / input_shape[1])
            ratio_w = float(output_shape[2] / input_shape[2])
            ratio_f = 1.0
            resized_shape = [input_shape[0], *output_shape[1:3], input_shape[3]]

        hn_element = {
            "type": "resize",
            "input_shapes": [self._hn_shape_prep(input_shape)],
            "output_shapes": [self._hn_shape_prep(resized_shape)],
            "original_names": [] if original_names is None else original_names,
            "compilation_params": {
                "hw_layer_type_list": ["lcu"],
            },
            "params": {
                "resize_h_ratio_list": [ratio_h],
                "resize_w_ratio_list": [ratio_w],
                "resize_f_ratio_list": [ratio_f],
                "method": "nearest_neighbor",
                "resize_bilinear_pixels_mode": "disabled",
            },
        }

        resize_layer = self.init_common_layer(hn_element, full_name)
        self._fill_config(resize_layer, precision_mode=precision_mode, bias_mode="single_scale_decomposition")
        return resize_layer

    def create_ew_sub_layer(
        self,
        full_name,
        input_shapes,
        activation="linear",
        precision_mode="a8_w8",
        bias_mode="double_scale_initialization",
        original_names=None,
    ):
        """Create an elementwise-subtract layer; input repeats are derived from the two input shapes."""
        # Computed first so incompatible shapes are rejected before anything else is built.
        repeats = self._create_elementwise_repeats(input_shapes)
        prepared_shapes = [self._hn_shape_prep(shape) for shape in input_shapes]

        hn_element = {
            "type": "ew_sub",
            "input_shapes": prepared_shapes,
            "original_names": [] if original_names is None else original_names,
            "params": {
                "activation": activation,
                "input_repeats": repeats,
            },
        }

        sub_layer = self.init_common_layer(hn_element, full_name)
        self._fill_config(sub_layer, precision_mode=precision_mode, bias_mode=bias_mode)
        return sub_layer

    def create_ew_add_layer(
        self,
        full_name,
        input_shapes,
        activation="linear",
        precision_mode="a8_w8",
        bias_mode="double_scale_initialization",
        original_names=None,
    ):
        """Create an elementwise-add layer; input repeats are derived from the two input shapes."""
        # Computed first so incompatible shapes are rejected before anything else is built.
        repeats = self._create_elementwise_repeats(input_shapes)
        prepared_shapes = [self._hn_shape_prep(shape) for shape in input_shapes]

        hn_element = {
            "type": "ew_add",
            "input_shapes": prepared_shapes,
            "original_names": [] if original_names is None else original_names,
            "params": {
                "activation": activation,
                "input_repeats": repeats,
            },
        }

        add_layer = self.init_common_layer(hn_element, full_name)
        self._fill_config(add_layer, precision_mode=precision_mode, bias_mode=bias_mode)
        return add_layer

    def create_ew_mult_layer(
        self,
        full_name,
        input_shapes,
        activation="linear",
        mock_kernel_values=None,
        precision_mode="a8_w8",
        bias_mode="double_scale_initialization",
        original_names=None,
    ):
        """
        Create an elementwise-multiply layer.

        The multiplication type is ``on_mac`` for ``a16_w16`` precision and ``on_apu`` otherwise.
        For ``a8_w8`` precision the layer also receives ``mock_kernel_values`` (``[2, 2]`` when
        none are given).
        """
        mode = PrecisionMode(precision_mode)
        if mode == PrecisionMode.a16_w16:
            mult_type = EWMultType.on_mac
        else:
            mult_type = EWMultType.on_apu

        # Computed before the hn element so incompatible shapes are rejected first.
        repeats = self._create_elementwise_repeats(input_shapes)
        prepared_shapes = [self._hn_shape_prep(shape) for shape in input_shapes]

        hn_element = {
            "type": "ew_mult",
            "input_shapes": prepared_shapes,
            "original_names": [] if original_names is None else original_names,
            "params": {
                "activation": activation,
                "ew_mult_type": mult_type,
                "input_repeats": repeats,
            },
        }

        mult_layer = self.init_common_layer(hn_element, full_name)
        if mode is PrecisionMode.a8_w8:
            mult_layer.mock_kernel_values = [2, 2] if mock_kernel_values is None else mock_kernel_values
        self._fill_config(mult_layer, precision_mode=mode, bias_mode=bias_mode)
        return mult_layer

    def create_square_layer(
        self,
        full_name,
        input_shape,
        activation="linear",
        reduce_sum_groups=None,
        mock_kernel_values=None,
        precision_mode="a8_w8",
        bias_mode="double_scale_initialization",
        original_names=None,
    ):
        """Create a ``feature_multiplier`` layer of type ``"square"`` and configure it.

        The multiplication runs ``"on_mac"`` for ``a16_w16`` precision and
        ``"on_apu"`` otherwise. When ``reduce_sum_groups`` is given, it is added
        to the layer params and replaces the last dimension of the output shape.

        Args:
            full_name: Name passed to ``init_common_layer`` for the new layer.
            input_shape: Single input shape, passed through ``_hn_shape_prep``.
            activation: Value stored under ``params["activation"]``.
            reduce_sum_groups: Optional group count; only allowed with ``on_mac``.
            mock_kernel_values: Assigned to the layer only in ``a8_w8`` mode;
                ``None`` means ``[2, 2]``.
            precision_mode: Anything accepted by ``PrecisionMode(...)``.
            bias_mode: Forwarded to ``_fill_config``.
            original_names: Optional list stored under ``"original_names"``;
                ``None`` is replaced by an empty list.

        Returns:
            The layer object returned by ``init_common_layer``.

        Raises:
            ValueError: If ``reduce_sum_groups`` is given while the layer is not
                ``on_mac``.
        """
        precision_mode = PrecisionMode(precision_mode)
        runs_on_mac = precision_mode == PrecisionMode.a16_w16
        mult_type = "on_mac" if runs_on_mac else "on_apu"
        input_shapes = [self._hn_shape_prep(input_shape)]

        grouped = reduce_sum_groups is not None
        if grouped and not runs_on_mac:
            raise ValueError("Square layer doesn't support reduce sum groups if it is not on_mac")
        if original_names is None:
            original_names = []

        params = {
            "activation": activation,
            "feature_multiplier_type": "square",
            "ew_mult_type": mult_type,
        }
        if grouped:
            params["reduce_sum_groups"] = reduce_sum_groups
            # Keep the first three dims and swap the last one for the group count.
            output_shapes = [[*shape[:3], reduce_sum_groups] for shape in input_shapes]
        else:
            output_shapes = list(input_shapes)

        hn_description = {
            "type": "feature_multiplier",
            "original_names": original_names,
            "input_shapes": input_shapes,
            "output_shapes": output_shapes,  # feature multiplier uses output shapes
            "params": params,
        }
        layer = self.init_common_layer(hn_description, full_name)
        if precision_mode == PrecisionMode.a8_w8:
            layer.mock_kernel_values = [2, 2] if mock_kernel_values is None else mock_kernel_values
        self._fill_config(layer, precision_mode=precision_mode, bias_mode=bias_mode)
        return layer

    def create_precision_split_layer(
        self,
        full_name,
        input_shape,
        original_names=None,
    ):
        """Create a ``precision_splitter`` layer from an HN description and configure it.

        The layer is always configured with ``PrecisionMode.a16_w16`` and
        ``BiasMode.single_scale_decomposition``.

        Args:
            full_name: Name passed to ``init_common_layer`` for the new layer.
            input_shape: Single input shape, passed through ``_hn_shape_prep``.
            original_names: Optional list stored under ``"original_names"``;
                ``None`` is replaced by an empty list.

        Returns:
            The layer object returned by ``init_common_layer``.
        """
        # Normalize None to an empty list, as the other create_* factories do,
        # so the HN element never carries ``"original_names": None``.
        original_names = original_names if original_names is not None else []
        hn_element = {
            "type": "precision_splitter",
            "input_shapes": [self._hn_shape_prep(input_shape)],
            "original_names": original_names,
        }
        new_layer = self.init_common_layer(hn_element, full_name)
        self._fill_config(
            new_layer,
            precision_mode=PrecisionMode.a16_w16,
            bias_mode=BiasMode.single_scale_decomposition,
        )
        return new_layer

    def create_shortcut_layer(self, full_name, input_shape, precision_mode, original_names=None):
        """Create a ``shortcut`` layer from an HN description and configure it.

        The bias mode is always ``BiasMode.single_scale_decomposition``.

        Args:
            full_name: Name passed to ``init_common_layer`` for the new layer.
            input_shape: Single input shape, passed through ``_hn_shape_prep``.
            precision_mode: Forwarded to ``_fill_config``.
            original_names: Optional list stored under ``"original_names"``;
                ``None`` is replaced by an empty list.

        Returns:
            The layer object returned by ``init_common_layer``.
        """
        if original_names is None:
            original_names = []
        prepared_shape = self._hn_shape_prep(input_shape)
        hn_description = {
            "type": "shortcut",
            "input_shapes": [prepared_shape],
            "original_names": original_names,
        }
        layer = self.init_common_layer(hn_description, full_name)
        self._fill_config(
            layer,
            precision_mode=precision_mode,
            bias_mode=BiasMode.single_scale_decomposition,
        )
        return layer

    def create_standalone_activation(
        self,
        full_name,
        input_shape,
        activation="linear",
        precision_mode="a8_w8",
        bias_mode="single_scale_decomposition",
        original_names=None,
    ):
        """Create an ``activation`` layer from an HN description and configure it.

        Args:
            full_name: Name passed to ``init_common_layer`` for the new layer.
            input_shape: Single input shape, passed through ``_hn_shape_prep``.
            activation: Value stored under ``params["activation"]``.
            precision_mode: Forwarded to ``_fill_config``.
            bias_mode: Forwarded to ``_fill_config``.
            original_names: Optional list stored under ``"original_names"``;
                ``None`` is replaced by an empty list.

        Returns:
            The layer object returned by ``init_common_layer``.
        """
        if original_names is None:
            original_names = []
        prepared_shape = self._hn_shape_prep(input_shape)
        hn_description = {
            "type": "activation",
            "input_shapes": [prepared_shape],
            "original_names": original_names,
            "params": {"activation": activation},
        }
        layer = self.init_common_layer(hn_description, full_name)
        self._fill_config(layer, bias_mode=bias_mode, precision_mode=precision_mode)
        return layer

    def create_reduce_max_layer(
        self,
        full_name,
        input_shape,
        reduce_axes=None,
        groups=1,
        precision_mode="a8_w8",
        original_names=None,
    ):
        """Create a ``reduce_max`` layer from an HN description and configure it.

        The bias mode is always ``BiasMode.single_scale_decomposition``.

        Args:
            full_name: Name passed to ``init_common_layer`` for the new layer.
            input_shape: Single input shape, passed through ``_hn_shape_prep``.
            reduce_axes: Axes to reduce over; ``None`` means ``[3]``. Must be
                one of ``(1, 2)``, ``(1,)``, ``(2,)`` or ``(3,)``.
            groups: Value stored under ``params["groups"]``.
            precision_mode: Forwarded to ``_fill_config``.
            original_names: Optional list stored under ``"original_names"``;
                ``None`` is replaced by an empty list.

        Returns:
            The layer object returned by ``init_common_layer``.

        Raises:
            ValueError: If ``reduce_axes`` is not one of the supported combinations.
        """
        if reduce_axes is None:
            reduce_axes = [3]
        # NOTE: this validation lives here; it is an open question why the hailo
        # layer itself does not check it.
        if tuple(reduce_axes) not in {(1, 2), (1,), (2,), (3,)}:
            raise ValueError(
                f"Reduce axes configuration {tuple(reduce_axes)} is not supported; "
                "expected one of (1, 2), (1,), (2,), (3,)"
            )
        original_names = original_names if original_names is not None else []

        hn_element = {
            "type": "reduce_max",
            "input_shapes": [self._hn_shape_prep(input_shape)],
            "original_names": original_names,
            "params": {
                "reduce_axes": reduce_axes,
                "groups": groups,
            },
        }
        new_layer = self.init_common_layer(hn_element, full_name)
        self._fill_config(new_layer, precision_mode=precision_mode, bias_mode=BiasMode.single_scale_decomposition)
        return new_layer

    # endregion

    # region add generic function

    def add_config(
        self,
        new_layer,
        bias_mode="single_scale_decomposition",
        precision_mode="a16_w16",
    ):
        """Build a precision config for ``new_layer``, store it and import it into the layer.

        A ``LayerPrecisionConfig`` with ``quantization_groups=1`` is created,
        stored in the model precision config under the layer's ``full_name``,
        and then handed to ``new_layer.import_precision_config`` together with
        ``self.optimization_target``.

        Args:
            new_layer: Layer exposing ``full_name`` and ``import_precision_config``.
            bias_mode: Bias mode for the new precision config.
            precision_mode: Precision mode for the new precision config.
        """
        name = new_layer.full_name
        precision_cfg = LayerPrecisionConfig(
            precision_mode=precision_mode,
            bias_mode=bias_mode,
            quantization_groups=1,
        )
        per_layer_cfgs = self._model_config.precision_config.layers
        per_layer_cfgs[name] = precision_cfg
        new_layer.import_precision_config(precision_cfg, self.optimization_target)

    def _force_range(self, lname, range):
        """Set ``force_range_out`` on a layer's translation config unless already recorded.

        The layer's translation config is taken from the finalized layer
        configs (or a default one when missing) and stored back into the model
        config. If its ``meta`` has no ``"force_range_out"`` entry, a
        placeholder ``CommandMeta`` (``line=-1``, empty command) is added and
        ``force_range_out`` is set to ``range``; otherwise nothing else changes.

        Args:
            lname: Name of the layer whose translation config is updated.
            range: Value assigned to ``force_range_out``. The name shadows the
                builtin but is kept for caller compatibility.
        """
        finalized = self.finalize_layer_cfg(self._model_config.translation_config.layers)
        layers = self._model_config.translation_config.layers
        layers[lname] = finalized.get(lname, LayerTranslationConfig.get_default())
        layer_cfg = layers[lname]

        meta = layer_cfg.meta
        if meta is None:
            meta = {}
        if "force_range_out" in meta:
            # A force_range_out entry is already recorded in meta; leave it untouched.
            return

        meta["force_range_out"] = CommandMeta(line=-1, command="", is_glob=False)
        layer_cfg.force_range_out = range
        layer_cfg.meta = meta

    # endregion
