from collections import Counter
from math import ceil, floor

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.utils.acceleras_definitions import ConcatAxis
from hailo_sdk_client.model_translator.fuser.definitions import FuserMode, FuserSubMode
from hailo_sdk_client.model_translator.fuser.exceptions import BackendFuserException
from hailo_sdk_client.model_translator.graph_lookup import (
    BwdChainNode,
    FwdChainNode,
    look_for_node,
)
from hailo_sdk_client.model_translator.translator import HailoNNBaseConverter
from hailo_sdk_client.post_fuser.algorithms import LayerNormMapping
from hailo_sdk_client.tools.fuser.fuser_helper import FuserHelper
from hailo_sdk_client.tools.layers.layers_utils import calculate_padding
from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN
from hailo_sdk_common.hailo_nn.hn_definitions import (
    ActivationType,
    BlockType,
    DepthToSpaceType,
    FeatureMultiplierType,
    FormatConversionType,
    HnStage,
    LayerType,
    PaddingType,
    ResizeBilinearStreamingPolicy,
    ResizeMethod,
    SpaceToDepthType,
)
from hailo_sdk_common.hailo_nn.hn_layers import (
    ActivationLayer,
    ConcatLayer,
    ConstInputLayer,
    Conv2DLayer,
    DepthToSpaceLayer,
    EWAddLayer,
    EWDivLayer,
    EWMultLayer,
    EWSubLayer,
    ExternalPadLayer,
    FeatureMultiplierLayer,
    FeatureShuffleLayer,
    FeatureSplitterLayer,
    FormatConversionLayer,
    FusedBatchNormLayer,
    FusedConv2DLayer,
    FusedDenseLayer,
    FusedSliceLayer,
    FusedStandaloneActivationLayer,
    FusedStandaloneEWAddLayer,
    FusedStandaloneEWSubLayer,
    LayerNormalizationLayer,
    NormalizationLayer,
    PoolingLayer,
    ReduceL2Layer,
    ReduceMaxLayer,
    ReduceMeanLayer,
    ReduceMinLayer,
    ReduceSumLayer,
    RowSplitterLayer,
    ShortcutLayer,
    SliceLayer,
    SpaceToDepthLayer,
    SpatialSplitterLayer,
    WidthSplitterLayer,
)
from hailo_sdk_common.hailo_nn.hn_layers.layer_common import input_to_output_height_width
from hailo_sdk_common.hailo_nn.hn_layers.matmul import MatmulLayer
from hailo_sdk_common.logger.logger import default_logger


class HailoNNFuser(HailoNNBaseConverter):
    # Class-level tuning constants. None of them is read in the part of the class
    # reviewed here; their consumers are presumably the group-conv, resize and LSTM
    # passes defined further down - TODO confirm.
    GROUP_CONV_MIN_NUM_OF_GROUPS = 3
    RESIZE_RATIO_NEUTRAL_VALUE = 1.0
    MAXIMUM_RESIZE_RATIO_PER_LAYER = 16.0
    LSTM_SEQ_LEN_WARNING_LIMIT = 16
    # TODO: change to False to keep group norm as an atomic op
    # TODO: https://hailotech.atlassian.net/browse/SDK-50254
    # NOTE(review): the two TODOs above talk about a boolean flag, but no such flag is
    # defined next to them - confirm whether the flag was removed or moved elsewhere.

    def __init__(self, graph, net_name, end_node_names):
        """
        Prepare a fuser with an empty traversal state.

        Args:
            graph: source graph; forwarded to the base converter.
            net_name: name that is later assigned to the fused output graph.
            end_node_names: forwarded to the base converter.
        """
        super().__init__(graph, end_node_names)
        self._net_name = net_name
        self._logger = default_logger()

        # traversal cursor over the queue of input layers
        self._index = 0
        self._curr_num_successors = 0
        self._input_layers = None
        self._curr_input_layer = None
        self._curr_output_layer = None

        # the fused graph is only created once the folding phase starts
        self._output_graph = None

        # state machine starts outside of any fusible chain
        self._mode, self._sub_mode = FuserMode.outside, FuserSubMode.none

        # bookkeeping used to rebuild the edges of the fused graph
        self._input_layers_to_output_layers = {}
        self._input_layers_with_edges = []

        self._fuser_helper = FuserHelper(self._graph)

    @property
    def output_graph(self):
        """The fused graph; ``None`` until ``_create_fused_model`` has assigned it."""
        return self._output_graph

    def convert_model(self):
        # phase 1: run pre fusing function on parsed and pre-fused graph
        self._run_pre_fusing_flow()

        # phase 2: layer folding
        self._create_fused_model()

        # phase 3: perform post-fusing functions that are integral to the native model
        self._run_post_folding_optimizations()

        # phase 4: put final touches on the fused model (scope, indices, names)
        self._finalize_fused_model()
        return self._output_graph

    def _run_pre_fusing_flow(self):
        """
        Phase 1: invoke every pre-fusing pass, one after the other, in this fixed order.

        Each pass is a method of this class or of ``self._fuser_helper``. Most of their
        bodies are not part of the code reviewed here.
        NOTE(review): the passes presumably depend on each other's results, so the order
        below should not be changed without checking every pass - TODO confirm.
        """
        self._split_rnn_layers()
        self._split_lstm_layers()
        self._split_gru_layers()
        self._split_log_softmax_layers()
        self._split_ew_reduce_layers()
        self._split_reduce_min_layers()
        self._split_equal_layers()
        self._split_spatial_split_layers()
        self._fuser_helper.split_ew_add_n_layers()
        self._collapse_slice_chains()
        self._handle_strided_slices()
        self._handle_downsample_by_two_with_slice()
        self._convert_slices_to_feature_splitter()
        self._split_l2_normalization_layers()
        self._handle_reduce_sum_mean_layers()
        self._handle_interleaved_groups_reduce_layer()
        self._split_multi_layers_activations()
        self._handle_layer_norm()
        self._split_one_hot_layers()
        self._fuse_padding()
        self._apply_avgpool_correction()
        self._handle_group_convolutions()
        self._handle_disparity_resize()
        self._handle_null_resizes()
        self._handle_null_normalizations()
        self._handle_null_ew_const_input()
        self._split_scatter_nd_layers()
        self._handle_einsum_layers()
        self._handle_null_format_conversion()
        self._handle_null_concat()
        self._handle_null_slice()
        self._handle_spatial_flatten_before_dynamic_weights_layer()
        self._split_gcn_pooling_block()
        self._handle_hc_transpose()
        self._handle_pooling_ceil_mode()
        self._fuser_helper.handle_ew_div()
        self._fuser_helper.run_broadcast_ew()
        self._handle_conv16x16s16()
        self._update_resize_layers_methods()
        self._split_resize_layers()
        self._handle_avgpool1x1s1()
        self._handle_inv_sqrt_activation()
        self._handle_non_positive_range_matmul_transposed_input()
        self._handle_format_conversions()
        self._handle_neg_feature_shuffle()
        self._fuser_helper.replace_spatial_input_repeats_with_resize()  # TODO: remove after SDK-55542 is done

    def _run_post_folding_optimizations(self):
        """
        Phase 3: invoke the post-folding passes, one after the other, in this fixed order.

        ``convert_model`` calls this after ``_create_fused_model``. The bodies of the
        passes are not part of the code reviewed here.
        NOTE(review): presumably order-sensitive, like the pre-fusing flow - TODO confirm.
        """
        self._fuse_output_format_conversion()
        self._handle_multiple_output_splits()
        self._handle_depth_to_space_activations()
        self._handle_deconv1x1s1()
        self._handle_conv1x1_after_global_avgpool()
        self._handle_flat_to_frames_before_resizes()
        # Pass below is intentionally disabled until the linked ticket is resolved.
        # TODO: https://hailotech.atlassian.net/browse/SDK-45362
        # self._handle_dense_before_flat_to_frame()
        self._add_conv_before_dynamic_weights_layer()
        self._add_shortcut_to_empty_model()

    def _create_fused_model(self):
        """
        Phase 2: walk the source graph and fold its layers into a new output graph.

        A fresh ``HailoNN`` is created and seeded with metadata from the source graph.
        The source layers are then visited through a work queue (``self._input_layers``)
        that starts with the graph's input layers and is extended by
        ``_get_next_vertex`` while the loop is running. For every visited layer,
        ``_process_current_vertex`` either creates a new output layer or keeps the
        current one (i.e. the visited layer is folded into it). Afterwards the edges
        and the input/output indices of the output graph are rebuilt.
        """
        self._output_graph = HailoNN()
        self._output_graph.net_params.stage = HnStage.PRE_FUSED
        self._output_graph.net_params.version = self._graph.net_params.version
        self._output_graph.name = self._net_name
        self._output_graph.detected_anchors = self._graph.detected_anchors
        self._output_graph.blocks = self._graph.blocks
        self._input_layers = self._graph.get_all_input_layers()
        self._index = 0
        # primed here so that the successor count of the first iteration can be computed
        self._curr_input_layer = self._input_layers[self._index]

        # len() is re-evaluated every iteration because _get_next_vertex grows the queue
        while self._index < len(self._input_layers):
            # NOTE: the successor count is taken BEFORE the cursor advances, i.e. it
            # belongs to the layer visited in the previous iteration (on the first
            # iteration: to the first input layer itself). _process_current_vertex
            # reads it when deciding whether an activation may be folded.
            self._curr_num_successors = len(list(self._graph.successors(self._curr_input_layer)))
            self._curr_input_layer = self._input_layers[self._index]
            self._process_current_vertex()
            self._logger.debug(
                f"Fused layer Name={self._curr_input_layer.name}, Mode={self._mode}, Sub-mode={self._sub_mode}",
            )
            self._get_next_vertex()
            self._index += 1
            # remember which output layer this input layer ended up in (used for edges)
            self._input_layers_to_output_layers[self._curr_input_layer] = self._curr_output_layer
            self._curr_output_layer.add_fused_index(self._curr_input_layer.index)
            self._curr_output_layer.move_params(self._curr_input_layer)
            # if current input layer is fused to another layer, we also need to transfer its fused indices
            if self._curr_output_layer.index != self._curr_input_layer.index:
                for index in self._curr_input_layer.fused_indices:
                    if index not in self._curr_output_layer.fused_indices:
                        self._curr_output_layer.add_fused_index(index)
                        self._curr_output_layer.move_params(self._graph.layers_by_index[index])
                # the output layer now ends where the folded layer ended
                self._curr_output_layer.output_indices = self._curr_input_layer.output_indices.copy()
                self._curr_output_layer.outputs = self._curr_input_layer.outputs.copy()

        self._add_layers_connections()
        self._update_io_indices()
        self._output_graph.update_input_lists()
        # from here on the helper operates on the fused graph instead of the source graph
        self._fuser_helper.model = self._output_graph

    def _process_current_vertex(self):
        """
        Advance the fuser state machine by one step for ``self._curr_input_layer``.

        Depending on the layer's op, this either
        * creates a new output layer via ``_add_layer`` and sets ``_mode``/``_sub_mode``, or
        * folds the layer into ``self._curr_output_layer`` (bias add, fusible activation),
          in which case only ``_sub_mode`` (and the output layer's activation) changes.

        Raises:
            BackendFuserException: for a bias add that cannot be folded in the current
                state, and for any op that has no branch below.
        """
        # computed up front for every layer, but only consulted in the activation branch
        can_fuse_vertex = self._can_fuse_vertex()

        if self._curr_input_layer.op in [LayerType.input_layer, LayerType.const_input]:
            self._mode = FuserMode.outside
            self._sub_mode = FuserSubMode.none
            # has_edge=False: these layers are not registered for edge reconstruction
            self._add_layer(has_edge=False)

        elif self._curr_input_layer.op in [LayerType.base_conv, LayerType.base_dw, LayerType.base_deconv]:
            self._mode = FuserMode.conv
            self._sub_mode = FuserSubMode.op
            self._add_layer(FusedConv2DLayer)

        elif self._curr_input_layer.op == LayerType.base_dense:
            self._mode = FuserMode.dense
            self._sub_mode = FuserSubMode.op
            self._add_layer(FusedDenseLayer)

        elif self._curr_input_layer.op in [LayerType.avgpool, LayerType.global_avg_pool, LayerType.maxpool]:
            self._mode = FuserMode.pool
            self._sub_mode = FuserSubMode.none
            self._add_layer()

        elif self._curr_input_layer.op == LayerType.ew_mult:
            self._mode = FuserMode.ew_mult
            self._sub_mode = FuserSubMode.op
            self._add_layer()

        elif self._curr_input_layer.op == LayerType.matmul:
            self._mode = FuserMode.activation_fusible
            self._sub_mode = FuserSubMode.none
            self._add_layer()

        # ops that are copied as-is (same layer class) and put the fuser in "outside" mode
        elif self._curr_input_layer.op in [
            LayerType.output_layer,
            LayerType.concat,
            LayerType.resize,
            LayerType.format_conversion,
            LayerType.reduce_max,
            LayerType.space_to_depth,
            LayerType.feature_shuffle,
            LayerType.feature_splitter,
            LayerType.demux,
            LayerType.row_splitter,
            LayerType.argmax,
            LayerType.softmax,
            LayerType.feature_multiplier,
            LayerType.external_pad,
            LayerType.shortcut,
            LayerType.reduce_mean,
            LayerType.layer_normalization,
        ]:
            self._mode = FuserMode.outside
            self._sub_mode = FuserSubMode.none
            self._add_layer()

        elif self._curr_input_layer.op == LayerType.null:
            self._mode = FuserMode.outside
            self._sub_mode = FuserSubMode.none
            self._add_layer(ShortcutLayer)

        elif self._curr_input_layer.op == LayerType.depth_to_space:
            self._mode = FuserMode.depth_to_space
            self._sub_mode = FuserSubMode.none
            self._add_layer()

        elif self._curr_input_layer.op == LayerType.reduce_sum:
            self._mode = FuserMode.reduce_sum
            self._sub_mode = FuserSubMode.op
            self._add_layer()

        elif self._curr_input_layer.op in [LayerType.base_ew_add, LayerType.base_ew_sub]:
            self._mode = FuserMode.ew_add
            self._sub_mode = FuserSubMode.op
            if self._curr_input_layer.op == LayerType.base_ew_add:
                layer_class = FusedStandaloneEWAddLayer
            else:
                layer_class = FusedStandaloneEWSubLayer
            self._add_layer(layer_class)

        elif self._curr_input_layer.op == LayerType.base_slice:
            self._mode = FuserMode.outside
            self._sub_mode = FuserSubMode.none
            self._add_layer(FusedSliceLayer)

        elif self._curr_input_layer.op == LayerType.base_batch_norm:
            self._mode = FuserMode.bn
            self._sub_mode = FuserSubMode.op
            self._add_layer(FusedBatchNormLayer)

        elif self._curr_input_layer.op == LayerType.normalization:
            self._mode = FuserMode.normalization
            self._sub_mode = FuserSubMode.none
            self._add_layer()

        elif self._curr_input_layer.op == LayerType.width_splitter:
            self._mode = FuserMode.outside
            self._sub_mode = FuserSubMode.none
            self._add_layer()

        elif self._curr_input_layer.op == LayerType.bias_add:
            # a bias add is only accepted in conv/dense mode, and not once an
            # activation or batch norm has already been folded
            if self._mode not in [FuserMode.conv, FuserMode.dense] or self._sub_mode in [
                FuserSubMode.activation,
                FuserSubMode.bn,
            ]:
                raise BackendFuserException(
                    f"Unexpected fuser input: bias add layer {self._curr_input_layer.name} "
                    f"can't be fused to layer {self._curr_output_layer.name} (of type "
                    f"{self._curr_output_layer.op}), in mode={self._mode}, "
                    f"sub-mode={self._sub_mode}. (translated from "
                    f"[{', '.join(self._curr_input_layer.original_names)}]).",
                )

            # no new output layer: the bias add is folded into the current output layer
            self._sub_mode = FuserSubMode.bias_add

        elif self._curr_input_layer.op == LayerType.base_activation:
            # The activation is folded into the current output layer only if ALL of the
            # conditions hold; otherwise it becomes a standalone activation layer.
            if (
                self._mode not in [FuserMode.outside, FuserMode.depth_to_space]
                and self._sub_mode != FuserSubMode.activation
                and not (self._mode == FuserMode.pool and self._curr_output_layer.op == LayerType.maxpool)
                and self._curr_num_successors == 1
                and can_fuse_vertex
                and self._curr_output_layer.activation == ActivationType.linear
            ):
                self._sub_mode = FuserSubMode.activation
                self._curr_output_layer.activation = self._curr_input_layer.activation
            else:
                self._mode = FuserMode.outside
                self._sub_mode = FuserSubMode.none
                self._add_layer(FusedStandaloneActivationLayer)

        else:
            raise BackendFuserException(
                f'Unexpected fuser input: layer {self._curr_input_layer.name} of type '
                f'{self._curr_input_layer.op} in mode={self._mode}, '
                f'sub-mode={self._sub_mode}. (translated from '
                f'[{", ".join(self._curr_input_layer.original_names)}]).',
            )

    def _can_fuse_vertex(self):
        if self._curr_output_layer is not None:
            for succ in list(self._graph.successors(self._curr_input_layer)):
                for input_index in succ.input_indices:
                    if input_index in self._curr_output_layer.fused_indices:
                        return False

        return True

    def _add_layer(self, layer_class=None, has_edge=True):
        if layer_class is None:
            layer_class = type(self._curr_input_layer)
        self._curr_output_layer = layer_class.from_layer(self._curr_input_layer)
        self._output_graph.add_node(self._curr_output_layer)
        if has_edge:
            self._input_layers_with_edges.append(self._curr_input_layer)

    def _should_skip_neighbor(self, neighbor):
        if neighbor in self._input_layers:
            self._logger.debug(f"layer name={neighbor.name}, skipped since it is already in queue")
            return True

        return False

    def _get_next_vertex(self):
        """
        Enqueue the successors of the current input layer that are not queued yet.

        Successors whose op is listed in ``tail_ops`` are appended to the end of the
        queue. Every other successor is inserted right after the current position, so
        it is the next layer to be visited.
        """
        tail_ops = (
            LayerType.base_conv,
            LayerType.base_dw,
            LayerType.base_deconv,
            LayerType.base_dense,
            LayerType.base_ew_add,
            LayerType.base_ew_sub,
            LayerType.maxpool,
            LayerType.avgpool,
            LayerType.global_avg_pool,
            LayerType.resize,
            LayerType.depth_to_space,
            LayerType.feature_shuffle,
            LayerType.feature_splitter,
            LayerType.argmax,
            LayerType.softmax,
            LayerType.format_conversion,
            LayerType.ew_mult,
            LayerType.reduce_sum,
            LayerType.reduce_max,
            LayerType.row_splitter,
            LayerType.space_to_depth,
            LayerType.matmul,
            LayerType.external_pad,
            LayerType.feature_multiplier,
            LayerType.const_input,
        )

        for candidate in list(self._graph.successors(self._curr_input_layer)):
            self._logger.debug(f"Checking neighbor={candidate.name}, type={candidate.op}")
            if self._should_skip_neighbor(candidate):
                continue

            if candidate.op in tail_ops:
                self._logger.debug(f"Neighbor {candidate.name} of type {candidate.op} added to queue")
                self._input_layers.append(candidate)
                continue

            self._logger.debug(f"Neighbor {candidate.name} of type {candidate.op} inserted to head of queue")
            self._input_layers.insert(self._index + 1, candidate)

    def _add_layers_connections(self):
        for layer in self._input_layers_with_edges:
            self._add_layer_connections(layer)

    def _add_layer_connections(self, layer):
        preds = list(self._graph.predecessors(layer))
        for pred in preds:
            self._logger.debug(f"Trying to connect {layer.name} and {pred.name}")
            if pred in self._input_layers_to_output_layers:
                current_output_layer = self._input_layers_to_output_layers[layer]
                prev_output_layer = self._input_layers_to_output_layers[pred]
                if current_output_layer != prev_output_layer:
                    self._output_graph.add_edge(prev_output_layer, current_output_layer)
                    self._logger.debug(f"Connecting {layer.name} and {pred.name}")

    def _update_io_indices(self):
        self._logger.debug("Updating input indices")
        for layer in list(self._output_graph):
            self._update_layer_io_indices(layer)
        self._logger.debug("Finished updating input indices")

    def _update_layer_io_indices(self, layer):
        self._logger.debug("Replacing input indices")
        for input_index in layer.input_indices:
            fused_index = self._find_fused_index(input_index, layer)
            if fused_index != input_index:
                self._logger.debug(f"Replacing input index {input_index} with {fused_index}")
                layer.replace_input_index(input_index, fused_index)
                layer.replace_input_layer(
                    self._graph.get_layer_name_by_index(input_index),
                    self._graph.get_layer_name_by_index(fused_index),
                )

        self._logger.debug("Replacing output indices")
        for output_index in layer.output_indices:
            fused_index = self._find_fused_index(output_index, layer)
            if fused_index != output_index:
                self._logger.debug(f"Replacing output index {output_index} with {fused_index}")
                layer.replace_output_index(output_index, fused_index)
                layer.replace_output_layer(
                    self._graph.get_layer_name_by_index(output_index),
                    self._graph.get_layer_name_by_index(fused_index),
                )

    def _find_fused_index(self, original_index, original_layer):
        for layer in list(self._output_graph):
            for index in layer.fused_indices:
                if original_index == index:
                    self._logger.debug(f"Found fused index! layer: {layer.name}, index: {layer.index}")
                    return layer.index

        raise BackendFuserException(
            f"Couldn't find a fused index for {original_index} while updating IO indices in "
            f"layer {original_layer.name}. (translated from "
            f"[{', '.join(original_layer.original_names)}]).",
        )

    def _finalize_fused_model(self):
        """
        Phase 4: finalize the fused graph and mark it as being in the ``HN`` stage.

        Runs the output graph's own finalization steps in a fixed order (names and
        indices, output indices, scopes, output order, shapes, anchors, report blocks)
        and only then switches the stage.
        """
        fused = self._output_graph

        # naming / indexing
        fused.set_names_and_indices(force=True)
        fused.update_output_indices()
        fused.add_scopes()
        fused.update_output_layers_order(self._end_node_names)

        # shapes
        fused.calculate_shapes()
        fused.validate_shapes()

        # metadata
        fused.update_detected_anchors_info()
        fused.extract_parsing_report_blocks()

        fused.net_params.stage = HnStage.HN

    def _split_reduce_min_layers(self):
        """
        Rewrite every ReduceMin layer as Normalization -> ReduceMax -> Normalization.

        ReduceMin is treated as a "negative reduce max": reduce_min(x) == -(reduce_max(-x)).
        Both normalization layers are configured with mean [0] and std [-1].
        The replaced layers are removed only after the graph has been fully scanned.
        """
        created_layers = []
        replaced_layers = []

        for reduce_min in list(self._graph):
            if reduce_min.op != LayerType.reduce_min:
                continue

            # only the first predecessor is rewired
            producer = next(iter(self._graph.predecessors(reduce_min)))
            consumers = list(self._graph.successors(reduce_min))
            first_index = self._graph.get_next_index()

            negate_in = self._fuser_helper.create_layer(
                NormalizationLayer,
                first_index,
                "normalization_in",
                reduce_min,
                created_layers,
                [reduce_min.input_shape],
            )
            negate_in.mean = [0]
            negate_in.std = [-1]

            max_layer = self._fuser_helper.create_layer(
                ReduceMaxLayer,
                first_index + 1,
                f"reduce_max_{reduce_min.name}",
                reduce_min,
                created_layers,
                [reduce_min.output_shape],
            )

            negate_out = self._fuser_helper.create_layer(
                NormalizationLayer,
                first_index + 2,
                "normalization_out",
                reduce_min,
                created_layers,
                [reduce_min.output_shape],
            )
            negate_out.mean = [0]
            negate_out.std = [-1]

            # wire the incoming side of the new chain
            self._fuser_helper.add_preds(negate_in, [producer])
            self._fuser_helper.add_preds(max_layer, [negate_in])
            self._fuser_helper.add_preds(negate_out, [max_layer])
            for consumer in consumers:
                self._fuser_helper.replace_pred(consumer, reduce_min, negate_out)
            # wire the outgoing side of the new chain
            self._fuser_helper.replace_succ(producer, reduce_min, negate_in)
            self._fuser_helper.add_succs(negate_in, [max_layer])
            self._fuser_helper.add_succs(max_layer, [negate_out])
            self._fuser_helper.add_succs(negate_out, consumers)

            replaced_layers.append(reduce_min)
            self._logger.debug(f"Replaced {reduce_min.full_name_msg} with Norm-ReduceMax-Norm layers.")

        # finalize graph manipulation outside graph iteration
        for obsolete in replaced_layers:
            self._graph.remove_layer(obsolete)

    def _split_rnn_layers(self):
        """
        Unroll every RNN layer into an explicit chain of basic layers.

        For an RNN layer with ``seq_len = layer.input_width`` steps the following
        layers are created (all tagged with the same ``(BlockType.RNN, name)`` block):

        * one input conv (``conv_w_*``) that uses the RNN kernel/bias on the whole input;
        * one slice per step, cutting a single column (width ``i-1..i``) from the conv output;
        * a normalization after the first slice, whose mean is minus the recurrent conv
          of the initial hidden state (zeros when the layer has none);
        * for every step from the second on: a recurrent conv (``conv_r_*``) fed by the
          previous step's tanh, and an ew-add of that conv with the step's slice;
        * one tanh activation per step;
        * a concat of all tanh outputs along the width axis, which takes over the
          successors of the original layer.

        The original RNN layers are removed, and the new layers are relaxed into the
        graph, only after the scan over the graph is done.
        """
        new_layers = []
        layers_to_remove = []
        successors_meta_data = {}

        for layer in list(self._graph):
            if layer.op != LayerType.rnn:
                continue

            curr_index = self._graph.get_next_index()
            seq_len = layer.input_width
            block_info = (BlockType.RNN, layer.name)

            # ---- input conv, applied once to the whole sequence ----
            w_conv = Conv2DLayer.from_layer(layer)
            w_conv.index = curr_index
            curr_index += 1
            w_conv.name = f"conv_w_{layer.name}"
            w_conv.op = LayerType.base_conv
            w_conv.kernel = layer.kernel
            w_conv.bias = layer.bias
            w_conv.strides = [1, 1, 1, 1]
            w_conv.dilations = [1, 1, 1, 1]
            w_conv.padding = PaddingType.valid
            w_conv.block_info = block_info
            self._graph.add_node(w_conv)
            new_layers.append(w_conv)

            # ---- normalization that carries the initial hidden state ----
            normalization = NormalizationLayer.from_layer(layer)
            normalization.index = curr_index
            curr_index += 1
            normalization.name = f"normalization_h0_{layer.name}"
            if layer.initial_h is not None:
                initial_h = layer.initial_h
            else:
                initial_h = np.zeros((1, 1, 1, w_conv.output_features))
            kernel_initializer = tf.keras.initializers.Constant(layer.recurrent_kernel)
            if layer.recurrent_bias is not None:
                bias_initializer = tf.keras.initializers.Constant(layer.recurrent_bias)
            else:
                bias_initializer = "zeros"
            # throw-away Keras conv, only used to evaluate the recurrent conv on initial_h
            conv_h = tf.keras.layers.Conv2D(
                filters=w_conv.output_features,
                kernel_size=(1, 1),
                kernel_initializer=kernel_initializer,
                bias_initializer=bias_initializer,
            )
            # presumably the normalization computes (x - mean) / std, so a negated mean
            # with std 1 adds the recurrent term to the first step - TODO confirm
            normalization.mean = np.array(-conv_h(initial_h)).flatten().tolist()
            normalization.std = [1] * w_conv.output_features
            self._graph.add_node(normalization)
            new_layers.append(normalization)
            normalization.block_info = block_info

            # ---- one slice + one tanh per step ----
            slices = []
            tanhs = []
            for i in range(1, seq_len + 1):
                slice_layer = SliceLayer.from_layer(layer)
                slice_layer.index = curr_index
                curr_index += 1
                slice_layer.name = f"slice_{i}_{layer.name}"
                slice_layer.height_slice = [0, w_conv.output_height, 1]
                slice_layer.width_slice = [i - 1, i, 1]
                slice_layer.features_slice = [0, w_conv.output_features, 1]
                slice_layer.input_shapes = [w_conv.output_shape]
                slice_output_shape = slice_layer.input_shape.copy()
                # entry 2 of the shape is set to 1: a single column is kept per step
                slice_output_shape[2] = 1
                slice_layer.output_shapes = [slice_output_shape]
                slices.append(slice_layer)
                self._graph.add_node(slice_layer)
                new_layers.append(slice_layer)
                slice_layer.block_info = block_info

                tanh = ActivationLayer.from_layer(layer)
                tanh.index = curr_index
                curr_index += 1
                tanh.name = f"activation_{i}_{layer.name}"
                tanh.activation = ActivationType.tanh
                tanhs.append(tanh)
                self._graph.add_node(tanh)
                new_layers.append(tanh)
                tanh.block_info = block_info

            # ---- one recurrent conv + one ew-add per step, starting at step 2 ----
            r_convs = []
            ew_adds = []
            for i in range(2, seq_len + 1):
                r_conv = Conv2DLayer.from_layer(layer)
                r_conv.index = curr_index
                curr_index += 1
                r_conv.name = f"conv_r_{i}_{layer.name}"
                r_conv.op = LayerType.base_conv
                r_conv.kernel = layer.recurrent_kernel
                r_conv.bias = layer.recurrent_bias
                r_conv.kernel_shape = r_conv.kernel.shape
                r_conv.strides = [1, 1, 1, 1]
                r_conv.dilations = [1, 1, 1, 1]
                r_conv.padding = PaddingType.valid
                r_convs.append(r_conv)
                self._graph.add_node(r_conv)
                new_layers.append(r_conv)
                r_conv.block_info = block_info

                ew_add = EWAddLayer.from_layer(layer)
                ew_add.index = curr_index
                curr_index += 1
                ew_add.name = f"ew_add_{i}_{layer.name}"
                ew_adds.append(ew_add)
                self._graph.add_node(ew_add)
                new_layers.append(ew_add)
                ew_add.block_info = block_info

            # ---- concat that gathers the outputs of all steps ----
            concat = ConcatLayer.from_layer(layer)
            concat.axis = ConcatAxis.spatial_w
            concat.index = curr_index
            curr_index += 1
            concat.name = f"concat_{layer.name}"
            self._graph.add_node(concat)
            new_layers.append(concat)
            concat.block_info = block_info

            # ---- wiring: input conv -> every slice ----
            w_conv.outputs = []
            w_conv.output_indices = []
            w_conv.output_shapes = []
            for slice_layer in slices:
                w_conv.append_output_layer(slice_layer.name)
                w_conv.append_output_index(slice_layer.index)
                w_conv.append_output_shape(slice_layer.input_shape)
                slice_layer.inputs = [w_conv.name]
                slice_layer.input_indices = [w_conv.index]
                self._graph.add_edge(w_conv, slice_layer)

            # ---- wiring: first slice -> normalization -> first tanh ----
            first_slice = slices[0]
            first_slice.outputs = [normalization.name]
            first_slice.output_indices = [normalization.index]
            normalization.input_shapes = [first_slice.output_shape]
            normalization.output_shapes = [normalization.input_shape]
            normalization.inputs = [first_slice.name]
            normalization.input_indices = [first_slice.index]
            self._graph.add_edge(first_slice, normalization)

            first_tanh = tanhs[0]
            first_tanh.inputs = [normalization.name]
            first_tanh.input_indices = [normalization.index]
            first_tanh.input_shapes = [normalization.output_shape]
            first_tanh.output_shapes = [first_tanh.input_shape]
            normalization.outputs = [first_tanh.name]
            normalization.output_indices = [first_tanh.index]
            self._graph.add_edge(normalization, first_tanh)

            # ---- wiring: remaining slices -> ew-adds (first input of each ew-add) ----
            for slice_layer, ew_add in zip(slices[1:], ew_adds):
                slice_layer.outputs = [ew_add.name]
                slice_layer.output_indices = [ew_add.index]
                ew_add.input_shapes = [slice_layer.output_shape]
                ew_add.output_shapes = [ew_add.input_shape]
                ew_add.inputs = [slice_layer.name]
                ew_add.input_indices = [slice_layer.index]
                self._graph.add_edge(slice_layer, ew_add)

            # ---- wiring: ew-adds -> tanh of the same step ----
            for ew_add, tanh in zip(ew_adds, tanhs[1:]):
                tanh.inputs = [ew_add.name]
                tanh.input_indices = [ew_add.index]
                tanh.input_shapes = [ew_add.output_shape]
                ew_add.outputs = [tanh.name]
                ew_add.output_indices = [tanh.index]
                self._graph.add_edge(ew_add, tanh)

            # the last tanh has no recurrent consumer; its outputs start empty and
            # only receive the concat below
            tanhs[-1].outputs = []
            tanhs[-1].output_indices = []
            tanhs[-1].output_shapes = []

            # ---- wiring: tanh of step i -> recurrent conv of step i+1 ----
            for tanh, r_conv in zip(tanhs[:-1], r_convs):
                tanh.outputs = [r_conv.name]
                tanh.output_indices = [r_conv.index]
                tanh.output_shapes = [tanh.input_shape]
                r_conv.inputs = [tanh.name]
                r_conv.input_indices = [tanh.index]
                r_conv.input_shapes = [tanh.output_shape]
                r_conv.output_shapes = [r_conv.input_shape]
                self._graph.add_edge(tanh, r_conv)

            # ---- wiring: recurrent conv -> ew-add (second input of each ew-add) ----
            for r_conv, ew_add in zip(r_convs, ew_adds):
                r_conv.outputs = [ew_add.name]
                r_conv.output_indices = [ew_add.index]
                ew_add.append_input_layer(r_conv.name)
                ew_add.append_input_index(r_conv.index)
                ew_add.append_input_shape(r_conv.output_shape)
                self._graph.add_edge(r_conv, ew_add)

            # ---- wiring: every tanh -> concat ----
            concat.inputs = []
            concat.input_indices = []
            concat.input_shapes = []
            for tanh in tanhs:
                tanh.append_output_layer(concat.name)
                tanh.append_output_index(concat.index)
                tanh.append_output_shape(tanh.input_shape)
                concat.append_input_layer(tanh.name)
                concat.append_input_index(tanh.index)
                concat.append_input_shape(tanh.output_shape)
                self._graph.add_edge(tanh, concat)

            # ---- splice the unrolled block in place of the original layer ----
            for pred in list(self._graph.predecessors(layer)):
                pred.replace_output_layer(layer.name, w_conv.name)
                pred.replace_output_index(layer.index, w_conv.index)
                self._graph.remove_edge(pred, layer)
                self._graph.add_edge(pred, w_conv)

            for succ in list(self._graph.successors(layer)):
                succ.replace_input_index(layer.index, concat.index)
                succ.replace_input_layer(layer.name, concat.name)
                self._graph.remove_edge(layer, succ)
                self._graph.add_edge(concat, succ)
                HailoNN.update_successors_meta_data(succ, successors_meta_data)

            layers_to_remove.append(layer)

            # logged exactly once per unrolled layer
            self._logger.debug(f"Unrolled RNN layer {layer.name} to {seq_len} blocks")
            if seq_len > 32:
                self._logger.warning(
                    f"The number of layers produced by RNN layer {layer.original_names[0]} "
                    f"unrolling is very large and may effect the performance",
                )

        # finalize graph manipulation outside graph iteration
        for layer in layers_to_remove:
            self._graph.remove_layer(layer)

        for layer in new_layers:
            self._graph.relax_new_layer_into_graph(layer, successors_meta_data)

    def _split_lstm_layers(self):
        """Replace every LSTM layer in the graph with an unrolled sub-graph of primitive layers.

        For each layer whose op is ``LayerType.lstm`` one unrolled direction is built with
        ``_build_one_direction_lstm``; when the layer direction is ``"bidirectional"`` a second ("bw")
        direction is built and both direction outputs are joined by an extra concat layer. The LSTM's
        predecessors are rewired to the input convolution(s) and its successors to the final layer of
        the unrolled block. Removal of the original LSTM layers and relaxation of the new layers are
        deferred until the graph iteration is over.

        Raises:
            BackendFuserException: when the sequence length exceeds
                ``LSTM_SEQ_LEN_WARNING_LIMIT // layer.num_directions``.
        """
        unrolled_layers = []
        replaced_lstms = []
        successors_meta_data = {}

        for layer in list(self._graph):
            if layer.op != LayerType.lstm:
                continue

            curr_index = self._graph.get_next_index()
            seq_len = layer.input_width
            # the kernel's last dimension is divided by 4 to get the per-step hidden feature count
            hidden_shape = [-1, 1, 1, layer.kernel_shape[-1] // 4]
            block_layers = []
            is_bidirectional = layer.direction == "bidirectional"

            w_conv, concat, curr_index = self._build_one_direction_lstm(
                block_layers,
                layer,
                curr_index,
                seq_len,
                hidden_shape,
            )
            block_output = concat

            if is_bidirectional:
                bw_w_conv, bw_concat, curr_index = self._build_one_direction_lstm(
                    block_layers,
                    layer,
                    curr_index,
                    seq_len,
                    hidden_shape,
                    direction="bw",
                )

                # join the forward and backward outputs with one more concat
                bidirectional_concat = ConcatLayer.from_layer(layer)
                bidirectional_concat.index = curr_index
                curr_index += 1
                bidirectional_concat.name = f"concat_bidirectional_{layer.name}"
                block_layers.append(bidirectional_concat)
                self._graph.add_node(bidirectional_concat)
                block_output = bidirectional_concat

                direction_concats = [concat, bw_concat]
                bidirectional_concat.inputs = [concat.name, bw_concat.name]
                bidirectional_concat.input_indices = [concat.index, bw_concat.index]
                bidirectional_concat.input_shapes = [hidden_shape] * 2
                for direction_concat in direction_concats:
                    direction_concat.outputs = [bidirectional_concat.name]
                    direction_concat.output_indices = [bidirectional_concat.index]
                    direction_concat.output_shapes = [hidden_shape]
                    self._graph.add_edge(direction_concat, bidirectional_concat)

            # predecessors of the LSTM now feed the input convolution(s)
            for pred in list(self._graph.predecessors(layer)):
                pred.replace_output_layer(layer.name, w_conv.name)
                pred.replace_output_index(layer.index, w_conv.index)
                self._graph.remove_edge(pred, layer)
                self._graph.add_edge(pred, w_conv)
                if not is_bidirectional:
                    continue
                pred.append_output_layer(bw_w_conv.name)
                pred.append_output_index(bw_w_conv.index)
                pred.append_output_shape(layer.input_shape)
                self._graph.add_edge(pred, bw_w_conv)

            # successors of the LSTM now read from the last layer of the unrolled block
            for succ in list(self._graph.successors(layer)):
                succ.replace_input_index(layer.index, block_output.index)
                succ.replace_input_layer(layer.name, block_output.name)
                self._graph.remove_edge(layer, succ)
                self._graph.add_edge(block_output, succ)
                HailoNN.update_successors_meta_data(succ, successors_meta_data)

            replaced_lstms.append(layer)
            unrolled_layers.extend(block_layers)
            for block_layer in block_layers:
                block_layer.block_info = (BlockType.LSTM, layer.name)

            self._logger.debug(f"Unrolled LSTM layer {layer.name} to {seq_len} blocks")

            # NOTE(review): the constant is named *_WARNING_LIMIT and the message is advisory, yet this
            # raises (after the graph was already rewired) - confirm that a hard failure is intended.
            seq_len_limit = self.LSTM_SEQ_LEN_WARNING_LIMIT // layer.num_directions
            if seq_len > seq_len_limit:
                raise BackendFuserException(
                    f"The number of layers produced by LSTM layer {layer.original_names[0]} "
                    f"unrolling is very large and may affect the performance ({seq_len*layer.num_directions}"
                    f" repeated blocks, total of {len(block_layers)} distinct layers).",
                )

        # finalize graph manipulation outside graph iteration
        for replaced in replaced_lstms:
            self._graph.remove_layer(replaced)

        for unrolled in unrolled_layers:
            self._graph.relax_new_layer_into_graph(unrolled, successors_meta_data)

        self._finalize_lstm_block(unrolled_layers)

    def _build_one_direction_lstm(self, new_layers, layer, curr_index, seq_len, hidden_shape, direction="fw"):
        """Unroll one direction of an LSTM layer into ``seq_len`` chained blocks of primitive layers.

        The created sub-graph consists of:
          * one input convolution (``w_conv``) applied to the whole sequence;
          * per time step: a width slice of ``w_conv``, a feature splitter feeding the four gate branches
            (i, o, f, g), the cell-state update (``ew_mult_ig`` and ``ew_mult_c`` summed by ``ew_add_c``,
            followed by a tanh) and the hidden-state product ``ew_mult_h``;
          * a width-axis concat of all ``ew_mult_h`` layers.

        For the first step the gate branches are normalization layers built from the initial hidden state
        (``_build_h0_normalization``) and the previous cell state is represented by ``normalization_c0``.
        For every later step the gate branches are recurrent convolutions fed by the previous step's
        ``ew_mult_h`` (``_build_r_conv_and_add``).

        Every created layer is added to ``self._graph`` and appended to ``new_layers``. Edges between the
        new sub-graph and the LSTM's original predecessors/successors are NOT created here.

        Args:
            new_layers: List that receives every layer created by this call (mutated in place).
            layer: The LSTM layer being unrolled; source of the weights and template for ``from_layer``.
            curr_index: First free layer index to assign.
            seq_len: Number of time steps to unroll.
            hidden_shape: Shape list assigned to the per-step hidden tensors. The same list object is
                stored by reference in many ``*_shapes`` lists; it is not copied.
            direction: ``"fw"`` uses the forward weights; any other value uses the ``bw_*`` weights and
                processes the sequence positions in reverse order.

        Returns:
            tuple: ``(w_conv, concat, curr_index)`` - the input convolution, the output concat and the
            next free layer index.
        """
        if direction == "fw":
            conv_w_kernel = layer.kernel
            conv_w_bias = layer.bias
            recurrent_kernel = layer.recurrent_kernel
            recurrent_bias = layer.recurrent_bias
            initial_h = layer.initial_h
            initial_c = layer.initial_c
        else:
            conv_w_kernel = layer.bw_kernel
            conv_w_bias = layer.bw_bias
            recurrent_kernel = layer.bw_recurrent_kernel
            recurrent_bias = layer.bw_recurrent_bias
            initial_h = layer.bw_initial_h
            initial_c = layer.bw_initial_c

        # missing initial states default to zeros
        if initial_h is None:
            initial_h = np.zeros([1] + hidden_shape[1:])
        initial_c = np.zeros(hidden_shape[-1]) if initial_c is None else initial_c.flatten()

        # input convolution over the whole sequence
        w_conv = Conv2DLayer.from_layer(layer)
        w_conv.index = curr_index
        curr_index += 1
        w_conv.name = f"conv_w_{direction}_{layer.name}"
        w_conv.op = LayerType.base_conv
        output_shape = layer.input_shape.copy()
        output_shape[-1] = layer.kernel.shape[-1]
        w_conv.output_shapes = [output_shape]
        w_conv.kernel = conv_w_kernel
        w_conv.kernel_shape = layer.kernel.shape
        w_conv.bias = conv_w_bias
        w_conv.strides = [1, 1, 1, 1]
        w_conv.dilations = [1, 1, 1, 1]
        w_conv.padding = PaddingType.valid
        self._graph.add_node(w_conv)
        new_layers.append(w_conv)

        # the recurrent weights hold the four gates side by side, in the order i, o, f, g
        r_i_kernel, r_o_kernel, r_f_kernel, r_g_kernel = np.split(recurrent_kernel, 4, axis=-1)

        if recurrent_bias is not None:
            r_i_bias, r_o_bias, r_f_bias, r_g_bias = np.split(recurrent_bias, 4)
        else:
            r_i_bias, r_o_bias, r_f_bias, r_g_bias = (None,) * 4

        gate_specs = [
            ("i", r_i_kernel, r_i_bias, ActivationType.sigmoid),
            ("o", r_o_kernel, r_o_bias, ActivationType.sigmoid),
            ("f", r_f_kernel, r_f_bias, ActivationType.sigmoid),
            ("g", r_g_kernel, r_g_bias, ActivationType.tanh),
        ]

        # first-step gate branches: one normalization per gate, each consuming a single index
        h0_normalizations = []
        for gate_name, gate_kernel, gate_bias, gate_activation in gate_specs:
            h0_normalizations.append(
                self._build_h0_normalization(
                    layer,
                    new_layers,
                    curr_index,
                    gate_kernel,
                    gate_bias,
                    gate_name,
                    gate_activation,
                    initial_h,
                    direction,
                    hidden_shape,
                ),
            )
            curr_index += 1
        normalization_r_i, normalization_r_o, normalization_r_f, normalization_r_g = h0_normalizations

        # stands in for the first-step "previous cell state" term; it is fed by the forget-gate branch
        normalization_c0 = NormalizationLayer.from_layer(layer)
        normalization_c0.index = curr_index
        curr_index += 1
        normalization_c0.name = f"normalization_c0_{direction}_{layer.name}"

        normalization_c0.mean = [0] * (hidden_shape[-1])
        # Ignore divide by zero error because we will divide again when converting std to kernel
        with np.errstate(divide="ignore"):
            normalization_c0.std = (1 / initial_c).tolist()
        self._graph.add_node(normalization_c0)
        new_layers.append(normalization_c0)

        # per-step layers that do not depend on the recurrent convolutions
        ew_mults_ig = []
        ew_adds_c = []
        tanhs_c = []
        slices = []
        fss = []
        ew_mults_h = []
        for i in range(1, seq_len + 1):
            slice_layer = SliceLayer.from_layer(layer)
            slice_layer.index = curr_index
            curr_index += 1
            slice_layer.name = f"slice_{i}_{direction}_{layer.name}"
            slice_layer.height_slice = [0, w_conv.output_height, 1]
            slice_layer.width_slice = [i - 1, i, 1]
            slice_layer.features_slice = [0, w_conv.output_features, 1]
            slice_layer.input_shapes = [w_conv.output_shape]
            slice_output_shape = slice_layer.input_shape.copy()
            slice_output_shape[2] = 1
            slice_layer.output_shapes = [slice_output_shape]
            # the backward direction stores the slices reversed, so the last sequence position is
            # paired with the first unrolled block when zipped with ``fss`` below
            if direction == "fw":
                slices.append(slice_layer)
            else:
                slices.insert(0, slice_layer)
            self._graph.add_node(slice_layer)
            new_layers.append(slice_layer)

            feature_splitter = FeatureSplitterLayer.from_layer(layer)
            feature_splitter.index = curr_index
            curr_index += 1
            feature_splitter.name = f"feature_splitter_{i}_{direction}_{layer.name}"
            feature_splitter.input_shapes = [slice_layer.output_shape]
            feature_splitter.output_shapes = [hidden_shape]
            fss.append(feature_splitter)
            self._graph.add_node(feature_splitter)
            new_layers.append(feature_splitter)

            ew_mult_ig = EWMultLayer.from_layer(layer)
            ew_mult_ig.index = curr_index
            curr_index += 1
            ew_mult_ig.name = f"ew_mult_ig_{i}_{direction}_{layer.name}"
            ew_mult_ig.input_shapes = [hidden_shape]
            ew_mult_ig.output_shapes = [hidden_shape]
            ew_mults_ig.append(ew_mult_ig)
            self._graph.add_node(ew_mult_ig)
            new_layers.append(ew_mult_ig)

            ew_add_c = EWAddLayer.from_layer(layer)
            ew_add_c.index = curr_index
            curr_index += 1
            ew_add_c.name = f"ew_add_c_{i}_{direction}_{layer.name}"
            ew_add_c.input_shapes = [hidden_shape]
            ew_add_c.output_shapes = [hidden_shape]
            ew_adds_c.append(ew_add_c)
            self._graph.add_node(ew_add_c)
            new_layers.append(ew_add_c)

            tanh_c = ActivationLayer.from_layer(layer)
            tanh_c.index = curr_index
            curr_index += 1
            tanh_c.name = f"activation_c_{i}_{direction}_{layer.name}"
            tanh_c.activation = ActivationType.tanh
            tanhs_c.append(tanh_c)
            self._graph.add_node(tanh_c)
            new_layers.append(tanh_c)

            ew_mult_h = EWMultLayer.from_layer(layer)
            ew_mult_h.index = curr_index
            curr_index += 1
            ew_mult_h.name = f"ew_mult_h_{i}_{direction}_{layer.name}"
            ew_mult_h.input_shapes = [hidden_shape]
            ew_mult_h.output_shapes = [hidden_shape]
            ew_mults_h.append(ew_mult_h)
            self._graph.add_node(ew_mult_h)
            new_layers.append(ew_mult_h)

        convs_r_i = []
        ew_adds_i = []
        activations_i = []
        convs_r_o = []
        ew_adds_o = []
        activations_o = []
        convs_r_f = []
        ew_adds_f = []
        activations_f = []
        convs_r_g = []
        ew_adds_g = []
        activations_g = []
        recurrent_collectors = [
            (convs_r_i, ew_adds_i, activations_i),
            (convs_r_o, ew_adds_o, activations_o),
            (convs_r_f, ew_adds_f, activations_f),
            (convs_r_g, ew_adds_g, activations_g),
        ]
        ew_mults_c = [normalization_c0]
        # recurrent branches exist only from the second step on
        for i in range(2, seq_len + 1):
            for (gate_name, gate_kernel, gate_bias, gate_activation), collectors in zip(
                gate_specs,
                recurrent_collectors,
            ):
                gate_r_convs, gate_ew_adds, gate_activations = collectors
                self._build_r_conv_and_add(
                    layer,
                    curr_index,
                    new_layers,
                    gate_kernel,
                    gate_bias,
                    gate_name,
                    gate_activation,
                    gate_r_convs,
                    gate_ew_adds,
                    gate_activations,
                    i,
                    direction,
                )
                # each call creates three layers: conv, ew_add, activation
                curr_index += 3

            ew_mult_c = EWMultLayer.from_layer(layer)
            ew_mult_c.index = curr_index
            curr_index += 1
            ew_mult_c.name = f"ew_mult_c_{i}_{direction}_{layer.name}"
            ew_mult_c.input_shapes = [hidden_shape]
            ew_mult_c.output_shapes = [hidden_shape]
            ew_mults_c.append(ew_mult_c)
            self._graph.add_node(ew_mult_c)
            new_layers.append(ew_mult_c)

        # w_conv -> slice -> feature_splitter -> four gate branches
        w_conv.outputs = []
        w_conv.output_indices = []
        w_conv.output_shapes = []
        for i, (slice_layer, feature_splitter) in enumerate(zip(slices, fss)):
            w_conv.append_output_layer(slice_layer.name)
            w_conv.append_output_index(slice_layer.index)
            w_conv.append_output_shape(slice_layer.input_shape)
            slice_layer.inputs = [w_conv.name]
            slice_layer.input_indices = [w_conv.index]
            self._graph.add_edge(w_conv, slice_layer)
            slice_layer.outputs = [feature_splitter.name]
            slice_layer.output_indices = [feature_splitter.index]
            feature_splitter.inputs = [slice_layer.name]
            feature_splitter.input_indices = [slice_layer.index]
            self._graph.add_edge(slice_layer, feature_splitter)
            if i == 0:
                adds_c = [normalization_r_i, normalization_r_o, normalization_r_f, normalization_r_g]
            else:
                adds_c = [ew_adds_i[i - 1], ew_adds_o[i - 1], ew_adds_f[i - 1], ew_adds_g[i - 1]]
            feature_splitter.outputs = []
            feature_splitter.output_indices = []
            feature_splitter.output_shapes = []
            for add_c in adds_c:
                feature_splitter.append_output_layer(add_c.name)
                feature_splitter.append_output_index(add_c.index)
                feature_splitter.append_output_shape(hidden_shape)
                add_c.inputs = [feature_splitter.name]
                add_c.input_indices = [feature_splitter.index]
                add_c.input_shapes = [hidden_shape]
                self._graph.add_edge(feature_splitter, add_c)

        # first-step forget branch -> normalization_c0
        normalization_r_f.outputs = [normalization_c0.name]
        normalization_r_f.output_indices = [normalization_c0.index]
        normalization_r_f.output_shapes = [hidden_shape]
        normalization_c0.inputs = [normalization_r_f.name]
        normalization_c0.input_indices = [normalization_r_f.index]
        normalization_c0.input_shapes = [hidden_shape]
        self._graph.add_edge(normalization_r_f, normalization_c0)

        # g branch and i branch -> ew_mult_ig
        for step, ew_mult_ig in enumerate(ew_mults_ig):
            if step == 0:
                gate = normalization_r_g
                input_gate = normalization_r_i
            else:
                gate = activations_g[step - 1]
                input_gate = activations_i[step - 1]
            ew_mult_ig.inputs = [gate.name, input_gate.name]
            ew_mult_ig.input_indices = [gate.index, input_gate.index]
            ew_mult_ig.input_shapes = [hidden_shape, hidden_shape]
            ew_mult_ig.output_shapes = [hidden_shape]
            gate.outputs = [ew_mult_ig.name]
            gate.output_indices = [ew_mult_ig.index]
            gate.input_shapes = [hidden_shape]
            gate.output_shapes = [hidden_shape]
            input_gate.outputs = [ew_mult_ig.name]
            input_gate.output_indices = [ew_mult_ig.index]
            input_gate.input_shapes = [hidden_shape]
            input_gate.output_shapes = [hidden_shape]
            self._graph.add_edge(gate, ew_mult_ig)
            self._graph.add_edge(input_gate, ew_mult_ig)

        # ew_mult_ig + ew_mult_c -> ew_add_c -> tanh_c
        for ew_add_c, ew_mult_ig, ew_mult_c, tanh_c in zip(ew_adds_c, ew_mults_ig, ew_mults_c, tanhs_c):
            ew_add_c.inputs = [ew_mult_ig.name, ew_mult_c.name]
            ew_add_c.input_indices = [ew_mult_ig.index, ew_mult_c.index]
            ew_add_c.input_shapes = [hidden_shape, hidden_shape]
            ew_mult_ig.outputs = [ew_add_c.name]
            ew_mult_ig.output_indices = [ew_add_c.index]
            ew_mult_ig.output_shapes = [hidden_shape]
            ew_mult_c.outputs = [ew_add_c.name]
            ew_mult_c.output_indices = [ew_add_c.index]
            ew_mult_c.output_shapes = [hidden_shape]
            self._graph.add_edge(ew_mult_ig, ew_add_c)
            self._graph.add_edge(ew_mult_c, ew_add_c)

            ew_add_c.outputs = [tanh_c.name]
            ew_add_c.output_indices = [tanh_c.index]
            ew_add_c.output_shapes = [hidden_shape]
            tanh_c.inputs = [ew_add_c.name]
            tanh_c.input_indices = [ew_add_c.index]
            tanh_c.input_shapes = [hidden_shape]
            self._graph.add_edge(ew_add_c, tanh_c)

        # o branch and tanh_c -> ew_mult_h; its outputs are filled in by the loops below
        for i, (tanh_c, ew_mult_h) in enumerate(zip(tanhs_c, ew_mults_h)):
            output_gate = normalization_r_o if i == 0 else activations_o[i - 1]
            tanh_c.outputs = [ew_mult_h.name]
            tanh_c.output_indices = [ew_mult_h.index]
            tanh_c.output_shapes = [hidden_shape]
            output_gate.outputs = [ew_mult_h.name]
            output_gate.output_indices = [ew_mult_h.index]
            output_gate.output_shapes = [hidden_shape]
            ew_mult_h.inputs = [output_gate.name, tanh_c.name]
            ew_mult_h.input_indices = [output_gate.index, tanh_c.index]
            ew_mult_h.input_shapes = [hidden_shape, hidden_shape]
            ew_mult_h.outputs = []
            ew_mult_h.output_indices = []
            ew_mult_h.output_shapes = []
            self._graph.add_edge(tanh_c, ew_mult_h)
            self._graph.add_edge(output_gate, ew_mult_h)

        # ew_mult_h of step t feeds the four recurrent convolutions of step t + 1
        # (zip stops at the shorter list, so the last ew_mult_h gets no recurrent consumer)
        r_convs_groups = [convs_r_i, convs_r_o, convs_r_f, convs_r_g]
        for r_convs in r_convs_groups:
            for ew_mult_h, r_conv in zip(ew_mults_h, r_convs):
                ew_mult_h.append_output_layer(r_conv.name)
                ew_mult_h.append_output_index(r_conv.index)
                ew_mult_h.append_output_shape(hidden_shape)
                r_conv.inputs = [ew_mult_h.name]
                r_conv.input_indices = [ew_mult_h.index]
                # this is the input edge, so the input shape is recorded here
                # (the output shape is assigned when the conv is wired to its ew_add below)
                r_conv.input_shapes = [hidden_shape]
                self._graph.add_edge(ew_mult_h, r_conv)

        ew_adds_groups = [ew_adds_i, ew_adds_o, ew_adds_f, ew_adds_g]
        activations_groups = [activations_i, activations_o, activations_f, activations_g]

        # recurrent conv -> ew_add (second input, after the feature splitter) -> gate activation
        for r_convs, ew_adds, activations in zip(r_convs_groups, ew_adds_groups, activations_groups):
            for r_conv, ew_add, activation in zip(r_convs, ew_adds, activations):
                r_conv.outputs = [ew_add.name]
                r_conv.output_indices = [ew_add.index]
                r_conv.output_shapes = [hidden_shape]
                ew_add.append_input_layer(r_conv.name)
                ew_add.append_input_index(r_conv.index)
                ew_add.append_input_shape(hidden_shape)
                self._graph.add_edge(r_conv, ew_add)

                ew_add.outputs = [activation.name]
                ew_add.output_indices = [activation.index]
                ew_add.output_shapes = [hidden_shape]
                activation.inputs = [ew_add.name]
                activation.input_indices = [ew_add.index]
                activation.input_shapes = [hidden_shape]
                self._graph.add_edge(ew_add, activation)

        # previous step's ew_add_c and this step's forget activation -> ew_mult_c
        for ew_mult_c, ew_add_c, sigmoid_f in zip(ew_mults_c[1:], ew_adds_c[:-1], activations_f):
            ew_add_c.append_output_layer(ew_mult_c.name)
            ew_add_c.append_output_index(ew_mult_c.index)
            ew_add_c.append_output_shape(hidden_shape)
            sigmoid_f.outputs = [ew_mult_c.name]
            sigmoid_f.output_indices = [ew_mult_c.index]
            sigmoid_f.output_shapes = [hidden_shape]
            ew_mult_c.inputs = [ew_add_c.name, sigmoid_f.name]
            ew_mult_c.input_indices = [ew_add_c.index, sigmoid_f.index]
            ew_mult_c.input_shapes = [hidden_shape, hidden_shape]
            self._graph.add_edge(ew_add_c, ew_mult_c)
            self._graph.add_edge(sigmoid_f, ew_mult_c)

        # the backward direction concatenates its hidden states in reverse block order
        ew_mults_h_for_concat = ew_mults_h if direction == "fw" else ew_mults_h[::-1]
        concat = ConcatLayer.from_layer(layer)
        concat.axis = ConcatAxis.spatial_w
        concat.index = curr_index
        curr_index += 1
        concat.name = f"concat_{direction}_{layer.name}"
        self._graph.add_node(concat)
        new_layers.append(concat)

        concat.inputs = []
        concat.input_indices = []
        concat.input_shapes = []
        for ew_mult_h in ew_mults_h_for_concat:
            ew_mult_h.append_output_layer(concat.name)
            ew_mult_h.append_output_index(concat.index)
            ew_mult_h.append_output_shape(hidden_shape)
            concat.append_input_layer(ew_mult_h.name)
            concat.append_input_index(ew_mult_h.index)
            concat.append_input_shape(hidden_shape)
            self._graph.add_edge(ew_mult_h, concat)

        return w_conv, concat, curr_index

    def _build_h0_normalization(
        self,
        layer,
        new_layers,
        curr_index,
        kernel,
        bias,
        name,
        activation,
        initial_h,
        direction,
        hidden_shape,
    ):
        """Create the first-step gate branch as a normalization layer derived from the initial hidden state.

        A 1x1 Keras convolution with constant ``kernel`` (and ``bias`` when given) is evaluated on
        ``initial_h``; the negated, flattened result becomes the normalization mean and the std is all
        ones. Assuming the normalization layer computes ``(x - mean) / std``, this adds the recurrent
        contribution of the initial hidden state to the layer input before ``activation`` is applied.

        The layer is added to ``self._graph`` and appended to ``new_layers``; no edges are created.

        Args:
            layer: The LSTM layer being unrolled (template for ``from_layer`` and source of the name).
            new_layers: List that receives the created layer (mutated in place).
            curr_index: Index assigned to the created layer.
            kernel: Recurrent kernel of a single gate.
            bias: Recurrent bias of the same gate, or ``None`` for a zero bias.
            name: Gate tag used in the layer name.
            activation: Activation type assigned to the normalization layer.
            initial_h: Initial hidden state fed to the convolution.
            direction: Direction tag used in the layer name.
            hidden_shape: Shape list whose last entry is the number of convolution filters.

        Returns:
            The created normalization layer.
        """
        normalization = NormalizationLayer.from_layer(layer)
        normalization.index = curr_index
        normalization.name = f"normalization_h0_{name}_{direction}_{layer.name}"

        kernel_initializer = tf.keras.initializers.Constant(kernel)
        # Decide by the bias actually passed in: for the backward direction it comes from
        # ``bw_recurrent_bias``, which can be set or missing independently of ``layer.recurrent_bias``.
        bias_initializer = tf.keras.initializers.Constant(bias) if bias is not None else "zeros"
        conv_h = tf.keras.layers.Conv2D(
            filters=hidden_shape[-1],
            kernel_size=(1, 1),
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
        )
        normalization.mean = np.array(-conv_h(initial_h)).flatten().tolist()
        normalization.std = [1] * hidden_shape[-1]
        normalization.activation = activation
        self._graph.add_node(normalization)
        new_layers.append(normalization)

        return normalization

    def _build_r_conv_and_add(
        self,
        layer,
        curr_index,
        new_layers,
        kernel,
        bias,
        name,
        activation_type,
        r_convs,
        ew_adds,
        activations,
        i,
        direction,
    ):
        """Create the recurrent conv, element-wise add and activation of one gate for one time step.

        The three layers receive the consecutive indices ``curr_index``, ``curr_index + 1`` and
        ``curr_index + 2``. Each one is added to ``self._graph``, appended to ``new_layers`` and to its
        per-gate collector list (``r_convs``, ``ew_adds``, ``activations``). No edges are created.

        Returns:
            tuple: ``(r_conv, ew_add, activation)``.
        """

        def register(created, collector):
            # publish a freshly configured layer to the graph and to both bookkeeping lists
            self._graph.add_node(created)
            new_layers.append(created)
            collector.append(created)

        recurrent_conv = Conv2DLayer.from_layer(layer)
        recurrent_conv.index = curr_index
        recurrent_conv.name = f"conv_r_{name}_{i}_{direction}_{layer.name}"
        recurrent_conv.op = LayerType.base_conv
        recurrent_conv.kernel = kernel
        recurrent_conv.bias = bias
        # read the shape back from the layer so it reflects the stored kernel
        recurrent_conv.kernel_shape = recurrent_conv.kernel.shape
        unit_steps = [1, 1, 1, 1]
        recurrent_conv.strides = list(unit_steps)
        recurrent_conv.dilations = list(unit_steps)
        recurrent_conv.padding = PaddingType.valid
        register(recurrent_conv, r_convs)

        gate_add = EWAddLayer.from_layer(layer)
        gate_add.index = curr_index + 1
        gate_add.name = f"ew_add_{name}_{i}_{direction}_{layer.name}"
        register(gate_add, ew_adds)

        gate_activation = ActivationLayer.from_layer(layer)
        gate_activation.index = curr_index + 2
        gate_activation.name = f"activation_{name}_{i}_{direction}_{layer.name}"
        gate_activation.activation = activation_type
        register(gate_activation, activations)

        return recurrent_conv, gate_add, gate_activation

    def _finalize_lstm_block(self, new_layers):
        """Remove the dead first-step cell-state branch of unrolled LSTM blocks.

        For every linear normalization layer in ``new_layers`` whose kernel is entirely zero, and whose
        neighbourhood matches the expected pattern (feature splitter -> normalization -> this layer ->
        ew_add, with an ew_mult before and after that ew_add), the method:
          * replaces the feature splitter with three slice layers feeding its first, second and fourth
            successors (the third successor is dropped);
          * lets the ew_mult feeding the ew_add take over the ew_add's outputs;
          * schedules the normalization, the feature splitter, the third successor and the ew_add for
            removal.
        """
        added_slices = []
        dead_layers = []
        successors_meta_data = {}
        next_index = self._graph.get_next_index()

        for layer in new_layers:
            # dead layer due to C(t-1) initialization with zeros in LSTM block
            if not (layer.op == LayerType.normalization and layer.activation == ActivationType.linear):
                continue
            if np.count_nonzero(layer.kernel) > 0:
                continue

            # handle only the edge case with complex slicing (the optimization algorithm doesn't solve it)
            feature_splitter = look_for_node(
                self._graph,
                layer,
                [BwdChainNode(op=LayerType.normalization), BwdChainNode(op=LayerType.feature_splitter)],
                exact_match=True,
            )
            ew_add = look_for_node(
                self._graph,
                layer,
                [FwdChainNode(op=LayerType.base_ew_add)],
                exact_match=True,
            )
            ew_mult = look_for_node(
                self._graph,
                layer,
                [FwdChainNode(op=LayerType.base_ew_add), BwdChainNode(op=LayerType.ew_mult)],
                exact_match=True,
            )
            ew_mult2 = look_for_node(
                self._graph,
                layer,
                [FwdChainNode(op=LayerType.base_ew_add), FwdChainNode(op=LayerType.ew_mult)],
                exact_match=True,
            )
            if not all([feature_splitter, ew_add, ew_mult, ew_mult2]):
                continue

            gate_width = feature_splitter.output_shapes[0][-1]
            norm_ri, norm_ro, norm_rf, norm_rg = list(self._graph.successors(feature_splitter))
            dead_layers.extend([layer, feature_splitter, norm_rf, ew_add])

            # one slice per surviving gate branch; position 2 (norm_rf) is removed, so it gets none
            splitter_slices = []
            for gate_pos, norm in ((0, norm_ri), (1, norm_ro), (3, norm_rg)):
                gate_slice = SliceLayer()
                gate_slice.index = next_index
                next_index += 1
                gate_slice.name = f"{feature_splitter.name}_slice_{gate_pos}"
                gate_slice.height_slice = [0, feature_splitter.output_height, 1]
                gate_slice.width_slice = [0, feature_splitter.output_width, 1]
                gate_slice.features_slice = [gate_width * gate_pos, gate_width * (gate_pos + 1), 1]
                gate_slice.input_shapes = feature_splitter.input_shapes.copy()
                gate_slice.output_shapes = [feature_splitter.output_shapes[gate_pos].copy()]
                gate_slice.inputs = feature_splitter.inputs.copy()
                gate_slice.input_indices = feature_splitter.input_indices.copy()
                gate_slice.outputs = [norm.name]
                gate_slice.output_indices = [norm.index]
                splitter_slices.append(gate_slice)

                self._graph.add_node(gate_slice)
                self._graph.add_edge(gate_slice, norm)
                self._graph.remove_edge(feature_splitter, norm)
                norm.replace_input_layer(feature_splitter.name, gate_slice.name)
                norm.replace_input_index(feature_splitter.index, gate_slice.index)
                HailoNN.update_successors_meta_data(norm, successors_meta_data)

            added_slices.extend(splitter_slices)

            # whatever fed the feature splitter now feeds the new slices directly
            for pred in list(self._graph.predecessors(feature_splitter)):
                pred.outputs = []
                pred.output_indices = []
                self._graph.remove_edge(pred, feature_splitter)
                for gate_slice in splitter_slices:
                    self._graph.add_edge(pred, gate_slice)
                    pred.outputs.append(gate_slice.name)
                    pred.output_indices.append(gate_slice.index)

            # the ew_add is scheduled for removal, so its ew_mult input takes over its consumers
            ew_mult.outputs = ew_add.outputs.copy()
            ew_mult.output_indices = ew_add.output_indices.copy()
            ew_mult.output_shapes = ew_add.output_shapes.copy()
            for succ in list(self._graph.successors(ew_add)):
                self._graph.add_edge(ew_mult, succ)
                self._graph.remove_edge(ew_add, succ)
                succ.replace_input_layer(ew_add.name, ew_mult.name)
                succ.replace_input_index(ew_add.index, ew_mult.index)

        for added_slice in added_slices:
            self._graph.relax_new_layer_into_graph(added_slice, successors_meta_data)

        for dead_layer in dead_layers:
            self._graph.remove_layer(dead_layer)

    def _split_gru_layers(self):
        """
        Replace every GRU layer in the graph with an unrolled sub-graph of simpler layers.

        The sequence length is read from the second-to-last dimension of the GRU's first input
        shape. When it is larger than 1, a spatial splitter (axis 2) cuts the input into single-step
        slices. For every step the following layers are created and wired:

        * an "update" conv holding ``layer.kernel`` (fed by the data input / sequence splitter) and a
          "reset" conv holding ``layer.recurrence_kernel`` (fed by the hidden state), each followed
          by a feature splitter with three outputs of ``layer.hidden_size`` features;
        * two ew_add + sigmoid pairs that combine outputs of both feature splitters;
        * ew_mult, ew_add, tanh, ew_sub, ew_mult and a final ew_add whose output is the hidden
          state of the step.

        The hidden state of the first step comes from the GRU's second predecessor when it has one,
        otherwise from a const input layer holding ``layer.initial_h``; later steps consume the
        previous step's hidden state. All per-step hidden states are joined by a concat layer
        (``ConcatAxis.spatial_w``), which takes over the successors of the original GRU layer.
        Every new layer is tagged with ``(BlockType.GRU, layer.name)`` block info.
        """
        for layer in list(self._graph):
            if layer.op == LayerType.gru:
                # the first input of the gru layer is the input data the second can be the hidden state
                preds = list(self._graph.predecessors(layer))
                base_index = self._graph.get_next_index()
                sequence_length = layer.input_shapes[0][-2]
                gru_input_shape = layer.input_shapes[0]
                hidden_state = None
                new_layers_by_name = {}
                intermediate_outputs = []

                # creates the concat layer for all the intermediate results
                concat = ConcatLayer()
                concat.name = f"gru_concat{base_index}"
                concat.original_names = layer.original_names.copy()
                concat.index = base_index
                concat.axis = ConcatAxis.spatial_w
                concat.input_shapes = [
                    [gru_input_shape[0], gru_input_shape[1], 1, gru_input_shape[3]],
                ] * sequence_length
                concat.output_shapes = [gru_input_shape.copy()]
                concat.block_info = (BlockType.GRU, layer.name)
                base_index += 1

                if sequence_length > 1:
                    # splits the input by the sequence length (axis 2, the w dimension of the input)
                    sequence_splitter = SpatialSplitterLayer()
                    sequence_splitter.index = base_index
                    sequence_splitter.name = f"gru_sequence_splitter_{sequence_splitter.index}"
                    sequence_splitter.original_names = layer.original_names.copy()
                    sequence_splitter.axis = 2
                    # the number of splits is the sequence length which is the w dimension of the input
                    sequence_splitter.split_sizes = [1] * sequence_length
                    sequence_splitter.input_shapes = [gru_input_shape.copy()]
                    sequence_splitter.output_shapes = [
                        [dim if i != 2 else 1 for i, dim in enumerate(gru_input_shape)],
                    ] * sequence_length
                    new_layers_by_name["sequence_width_splitter"] = sequence_splitter
                    # tag the splitter itself as part of the GRU block, like every other new layer
                    # (previously this line re-assigned concat.block_info and left the splitter untagged)
                    sequence_splitter.block_info = (BlockType.GRU, layer.name)
                    base_index += 1

                    update_conv_pred = sequence_splitter
                    self._fuser_helper.replace_succ(preds[0], layer, sequence_splitter)
                    self._fuser_helper.add_preds(sequence_splitter, [preds[0]], update_input_shapes=False)
                    conv_input_shape = sequence_splitter.output_shapes[0].copy()
                else:
                    # a single step: the data predecessor feeds the update conv directly
                    update_conv_pred = preds[0]
                    conv_input_shape = layer.input_shapes[0].copy()

                # defuses the gru to supported layers
                for i in range(sequence_length):
                    # adds blocks for the update gate, reset gate
                    for gate in ["update", "reset"]:
                        conv = Conv2DLayer()
                        conv.original_names = layer.original_names.copy()
                        conv.index = base_index
                        conv.name = f"gru_{gate}_conv{base_index}"
                        conv.kernel = layer.kernel if gate == "update" else layer.recurrence_kernel
                        conv.kernel_shape = conv.kernel.shape
                        # the first half of the bias is for the update gate and the second half is for the reset gate
                        bias_halves = np.array_split(layer.bias, 2, axis=-1)
                        conv.bias = bias_halves[0] if gate == "update" else bias_halves[1]
                        conv.strides = [1, 1, 1, 1]
                        conv.dilations = [1, 1, 1, 1]
                        conv.padding = PaddingType.valid
                        conv.input_shapes = [conv_input_shape.copy()]
                        spatial_shape = input_to_output_height_width(
                            conv.input_shape,
                            conv.kernel_shape,
                            conv.strides,
                            conv.padding,
                        )
                        conv.output_shapes = [[-1, *spatial_shape, conv.kernel_shape[-1]]]
                        conv.activation = ActivationType.linear
                        conv.block_info = (BlockType.GRU, layer.name)
                        base_index += 1
                        new_layers_by_name[f"{gate}_conv"] = conv

                        # splits the data conv to Wz, Wr, Wh
                        conv_splitter = FeatureSplitterLayer()
                        conv_splitter.index = base_index
                        conv_splitter.name = f"gru_{gate}_feature_splitter{conv_splitter.index}"
                        conv_splitter.original_names = layer.original_names.copy()
                        conv_splitter.split_sizes = [layer.hidden_size] * 3
                        conv_splitter.input_shapes = conv.output_shapes.copy()
                        conv_splitter.output_shapes = [[*conv.output_shape[:-1], layer.hidden_size]] * 3
                        new_layers_by_name[f"{gate}_feature_splitter"] = conv_splitter
                        conv_splitter.block_info = (BlockType.GRU, layer.name)
                        base_index += 1

                        # shape of a single splitter output; reused by all element-wise layers below
                        base_shape = conv_splitter.output_shapes[0]

                        ew_add = EWAddLayer()
                        ew_add.index = base_index
                        ew_add.name = f"gru_{gate}_ew_add{base_index}"
                        ew_add.original_names = layer.original_names.copy()
                        ew_add.input_shapes = [base_shape.copy()] * 2
                        ew_add.output_shapes = [base_shape.copy()]
                        new_layers_by_name[f"{gate}_ew_add"] = ew_add
                        ew_add.block_info = (BlockType.GRU, layer.name)
                        base_index += 1

                        sigmoid = ActivationLayer()
                        sigmoid.index = base_index
                        sigmoid.name = f"gru_{gate}_sigmoid{base_index}"
                        sigmoid.original_names = layer.original_names.copy()
                        sigmoid.input_shapes = [base_shape.copy()]
                        sigmoid.output_shapes = [base_shape.copy()]
                        new_layers_by_name[f"{gate}_sigmoid"] = sigmoid
                        sigmoid.activation = ActivationType.sigmoid
                        sigmoid.block_info = (BlockType.GRU, layer.name)
                        base_index += 1

                    # completes the reset block
                    ew_mult = EWMultLayer()
                    ew_mult.index = base_index
                    ew_mult.name = f"gru_ew_mult{base_index}"
                    ew_mult.original_names = layer.original_names.copy()
                    ew_mult.input_shapes = [base_shape.copy()] * 2
                    ew_mult.output_shapes = [base_shape.copy()]
                    new_layers_by_name["reset_ew_mult"] = ew_mult
                    ew_mult.block_info = (BlockType.GRU, layer.name)
                    base_index += 1

                    ew_add = EWAddLayer()
                    ew_add.index = base_index
                    ew_add.name = f"gru_reset_ew_add{base_index}"
                    ew_add.original_names = layer.original_names.copy()
                    ew_add.input_shapes = [base_shape.copy()] * 2
                    ew_add.output_shapes = [base_shape.copy()]
                    new_layers_by_name["reset_ew_add2"] = ew_add
                    ew_add.block_info = (BlockType.GRU, layer.name)
                    base_index += 1

                    # the tanh output feeds two consumers (reset_ew_add3 and reset_ew_sub)
                    tanh = ActivationLayer()
                    tanh.index = base_index
                    tanh.name = f"gru_reset_tanh{base_index}"
                    tanh.original_names = layer.original_names.copy()
                    tanh.input_shapes = [base_shape.copy()]
                    tanh.output_shapes = [base_shape.copy()] * 2
                    new_layers_by_name["reset_tanh"] = tanh
                    tanh.activation = ActivationType.tanh
                    tanh.block_info = (BlockType.GRU, layer.name)
                    base_index += 1

                    ew_sub = EWSubLayer()
                    ew_sub.index = base_index
                    ew_sub.name = f"gru_reset_ew_sub{base_index}"
                    ew_sub.original_names = layer.original_names.copy()
                    ew_sub.input_shapes = [base_shape.copy()] * 2
                    ew_sub.output_shapes = [base_shape.copy()]
                    new_layers_by_name["reset_ew_sub"] = ew_sub
                    ew_sub.block_info = (BlockType.GRU, layer.name)
                    base_index += 1

                    ew_mult = EWMultLayer()
                    ew_mult.index = base_index
                    ew_mult.name = f"gru_ew_mult{base_index}"
                    ew_mult.original_names = layer.original_names.copy()
                    ew_mult.input_shapes = [base_shape.copy()] * 2
                    ew_mult.output_shapes = [base_shape.copy()]
                    new_layers_by_name["update_ew_mult"] = ew_mult
                    ew_mult.block_info = (BlockType.GRU, layer.name)
                    base_index += 1

                    # final add of the step; its output is used as the step's hidden state
                    ew_add = EWAddLayer()
                    ew_add.index = base_index
                    ew_add.name = f"gru_reset_ew_add{base_index}"
                    ew_add.original_names = layer.original_names.copy()
                    ew_add.input_shapes = [base_shape.copy()] * 2
                    ew_add.output_shapes = [base_shape.copy()]
                    new_layers_by_name["reset_ew_add3"] = ew_add
                    ew_add.block_info = (BlockType.GRU, layer.name)
                    base_index += 1

                    if len(preds) == 1:
                        # the hidden state is given as a weight or input of the previous sequence
                        if i == 0:
                            # hidden state was given as a weight thus creates a const input layer
                            const_input = ConstInputLayer()
                            const_input.index = base_index
                            const_input.name = f"gru_hidden_state_const_input{base_index}"
                            const_input.original_names = layer.original_names.copy()
                            const_input.const_values = layer.initial_h
                            const_input.input_shapes = [[-1, *const_input.const_values.shape]]
                            const_input.output_shapes = [[-1, *const_input.const_values.shape]] * 2
                            const_input.block_info = (BlockType.GRU, layer.name)
                            base_index += 1
                            reset_conv_pred = const_input
                        else:
                            # the hidden state is the output of the previous block
                            reset_conv_pred = hidden_state
                            hidden_state.append_output_shape(base_shape.copy())
                        self._fuser_helper.add_succs(
                            reset_conv_pred,
                            [new_layers_by_name["reset_conv"]],
                            update_output_shapes=False,
                        )
                    else:
                        # the hidden state is an explicit second input of the gru layer
                        reset_conv_pred = preds[1]
                        self._fuser_helper.replace_succ(reset_conv_pred, layer, new_layers_by_name["reset_conv"])

                    # connects the new layers
                    layer_to_preds = {
                        new_layers_by_name["update_conv"]: [update_conv_pred],
                        new_layers_by_name["reset_conv"]: [reset_conv_pred],
                        new_layers_by_name["update_feature_splitter"]: [new_layers_by_name["update_conv"]],
                        new_layers_by_name["reset_feature_splitter"]: [new_layers_by_name["reset_conv"]],
                        new_layers_by_name["update_ew_add"]: [
                            new_layers_by_name["update_feature_splitter"],
                            new_layers_by_name["reset_feature_splitter"],
                        ],
                        new_layers_by_name["reset_ew_add"]: [
                            new_layers_by_name["update_feature_splitter"],
                            new_layers_by_name["reset_feature_splitter"],
                        ],
                        new_layers_by_name["update_sigmoid"]: [new_layers_by_name["update_ew_add"]],
                        new_layers_by_name["reset_sigmoid"]: [new_layers_by_name["reset_ew_add"]],
                        new_layers_by_name["reset_ew_mult"]: [
                            new_layers_by_name["reset_feature_splitter"],
                            new_layers_by_name["reset_sigmoid"],
                        ],
                        new_layers_by_name["reset_ew_add2"]: [
                            new_layers_by_name["update_feature_splitter"],
                            new_layers_by_name["reset_ew_mult"],
                        ],
                        new_layers_by_name["reset_tanh"]: [new_layers_by_name["reset_ew_add2"]],
                        new_layers_by_name["reset_ew_sub"]: [reset_conv_pred, new_layers_by_name["reset_tanh"]],
                        new_layers_by_name["update_ew_mult"]: [
                            new_layers_by_name["update_sigmoid"],
                            new_layers_by_name["reset_ew_sub"],
                        ],
                        new_layers_by_name["reset_ew_add3"]: [
                            new_layers_by_name["reset_tanh"],
                            new_layers_by_name["update_ew_mult"],
                        ],
                    }

                    self._fuser_helper.add_succs(update_conv_pred, [new_layers_by_name["update_conv"]])
                    self._fuser_helper.add_succs(reset_conv_pred, [new_layers_by_name["reset_ew_sub"]])

                    layers_to_succs = {
                        new_layers_by_name["update_conv"]: [new_layers_by_name["update_feature_splitter"]],
                        new_layers_by_name["reset_conv"]: [new_layers_by_name["reset_feature_splitter"]],
                        new_layers_by_name["update_feature_splitter"]: [
                            new_layers_by_name["update_ew_add"],
                            new_layers_by_name["reset_ew_add"],
                            new_layers_by_name["reset_ew_add2"],
                        ],
                        new_layers_by_name["reset_feature_splitter"]: [
                            new_layers_by_name["update_ew_add"],
                            new_layers_by_name["reset_ew_add"],
                            new_layers_by_name["reset_ew_mult"],
                        ],
                        new_layers_by_name["update_ew_add"]: [new_layers_by_name["update_sigmoid"]],
                        new_layers_by_name["reset_ew_add"]: [new_layers_by_name["reset_sigmoid"]],
                        new_layers_by_name["update_sigmoid"]: [new_layers_by_name["update_ew_mult"]],
                        new_layers_by_name["reset_sigmoid"]: [new_layers_by_name["reset_ew_mult"]],
                        new_layers_by_name["reset_ew_mult"]: [new_layers_by_name["reset_ew_add2"]],
                        new_layers_by_name["reset_ew_add2"]: [new_layers_by_name["reset_tanh"]],
                        new_layers_by_name["reset_tanh"]: [
                            new_layers_by_name["reset_ew_add3"],
                            new_layers_by_name["reset_ew_sub"],
                        ],
                        new_layers_by_name["reset_ew_sub"]: [new_layers_by_name["update_ew_mult"]],
                        new_layers_by_name["update_ew_mult"]: [new_layers_by_name["reset_ew_add3"]],
                    }

                    self._fuser_helper.handle_new_preds_succs(
                        layer_to_preds,
                        layers_to_succs,
                        update_inputs_shapes=False,
                        update_outputs_shapes=False,
                    )

                    hidden_state = new_layers_by_name["reset_ew_add3"]
                    # from the second step on there is a single pred, so the branch above takes the
                    # hidden state from the previous step instead of an external input
                    preds = [update_conv_pred]
                    intermediate_outputs.append(hidden_state)
                    concat.input_list.append(new_layers_by_name["reset_ew_add3"])

                for output in intermediate_outputs:
                    self._fuser_helper.add_preds(concat, [output], update_input_shapes=False)
                    self._fuser_helper.add_succs(output, [concat])

                # adds the succs to the new layers
                succs = list(self._graph.successors(layer))
                for succ in succs:
                    self._fuser_helper.replace_pred(succ, layer, concat)

                self._fuser_helper.add_succs(concat, succs, update_output_shapes=False)
                self._graph.remove_layer(layer)

    def _split_ew_reduce_layers(self):
        """
        Replace element-wise max / min / mean layers with a grouped concat followed by a grouped reduce.

        For every layer whose op is ``ew_max``, ``ew_min`` or ``base_ew_mean``:

        * a ``ConcatLayer`` ("grouped_concat") is created with ``group_sizes`` of 1 per input feature;
          its output features are the original output features times the number of inputs;
        * a matching reduce layer (``ReduceMaxLayer`` / ``ReduceMinLayer`` / ``ReduceMeanLayer``) is
          created after it, with ``groups`` equal to the concat output features divided by the number
          of concatenated inputs;
        * predecessors and successors of the original layer are rewired to the new pair, and the
          original layer is removed once the graph iteration is over.
        """
        new_layers = []
        layers_to_remove = []
        reduce_type_to_layer_class = {
            LayerType.ew_max: ReduceMaxLayer,
            LayerType.ew_min: ReduceMinLayer,
            LayerType.base_ew_mean: ReduceMeanLayer,
        }

        for layer in list(self._graph):
            if layer.op not in reduce_type_to_layer_class:
                continue

            # Use the last "_"-separated token of the op value as the reduce kind. Indexing the
            # second token (as before) gives the same result for two-token values, but would pick
            # the middle token for a prefixed value.
            # NOTE(review): assumes op values end with the reduce kind (max/min/mean) - confirm.
            ew_type_string = layer.op.value.rsplit("_", 1)[-1]

            preds = list(self._graph.predecessors(layer))
            succs = list(self._graph.successors(layer))
            base_idx = self._graph.get_next_index()

            grouped_concat = self._fuser_helper.create_layer(
                ConcatLayer,
                base_idx,
                "grouped_concat",
                layer,
                new_layers,
                [layer.output_shape],
            )
            grouped_concat.group_sizes = [1] * layer.input_features
            if layer.op == LayerType.base_ew_mean:
                # ops without input_list: resolve the inputs by name instead
                grouped_concat.input_list = [
                    self._graph.get_layer_by_name(layer_name) for layer_name in layer.inputs
                ]
            else:
                grouped_concat.input_list = layer.input_list.copy()

            # the concat stacks all inputs on the features axis
            output_shape = layer.output_shape.copy()
            output_shape[-1] *= len(grouped_concat.input_list)
            grouped_concat.output_shapes = [output_shape]

            reduce_type = reduce_type_to_layer_class[layer.op]
            grouped_reduce = self._fuser_helper.create_layer(
                reduce_type,
                base_idx + 1,
                f"grouped_reduce_{ew_type_string}",
                layer,
                new_layers,
                [layer.output_shape],
            )
            grouped_reduce.groups = grouped_concat.output_shape[-1] // len(grouped_concat.input_list)

            # rewire: preds -> grouped_concat -> grouped_reduce -> succs
            self._fuser_helper.add_preds(grouped_concat, preds)
            for pred in preds:
                self._fuser_helper.replace_succ(pred, layer, grouped_concat)
            self._fuser_helper.add_preds(grouped_reduce, [grouped_concat])
            self._fuser_helper.add_succs(grouped_concat, [grouped_reduce])
            self._fuser_helper.add_succs(grouped_reduce, succs)
            for succ in succs:
                self._fuser_helper.replace_pred(succ, layer, grouped_reduce)

            layers_to_remove.append(layer)
            self._logger.debug(f"Replaced {layer.full_name_msg} with Concat-Reduce{ew_type_string} layers.")

        # finalize graph manipulation outside graph iteration
        for layer in layers_to_remove:
            self._graph.remove_layer(layer)

    def _split_equal_layers(self):
        """
        Replace every ``equal`` layer with a subtraction-like layer followed by a delta activation.

        The handled cases are:

        * Case 1: single input and zero dim const
        * Case 2: single input and one dim (features) const
        * Case 3: single input and 3-dim const
        * Case 4: two inputs

        Cases 1+2 use a normalization layer whose mean is the constant and whose std is 1.
        Cases 3+4 use an ew_sub layer; when the original layer carries a constant input, a new
        const input layer is created and connected as an additional input of the ew_sub.

        Raises:
            BackendFuserException: if the layer has a single predecessor and no constant input.
        """
        new_layers = []
        layers_to_remove = []
        successors_meta_data = {}

        for layer in list(self._graph):
            curr_index = self._graph.get_next_index()
            if layer.op == LayerType.equal:
                const_layer = None
                preds = list(self._graph.predecessors(layer))
                if len(preds) == 1 and layer.constant_input is None:
                    raise BackendFuserException(f"Can not find two inputs to compare for layer {layer.full_name_msg}")

                if len(preds) == 1 and len(layer.constant_input.shape) in [0, 1]:
                    # Adding normalization for cases 1+2
                    equal_layer = NormalizationLayer.from_layer(layer)
                    equal_layer.mean = layer.constant_input
                    equal_layer.std = [1]
                    equal_layer.update_mean_and_std()
                else:
                    # Adding ew_sub for cases 3+4
                    equal_layer = EWSubLayer.from_layer(layer)

                    if layer.constant_input is not None:
                        # Adding const input layer for case 3
                        const_layer = ConstInputLayer.from_layer(layer)
                        const_layer.name = f"const_input_{layer.name}"
                        const_layer.index = curr_index
                        curr_index += 1
                        const_layer.inputs = []
                        const_layer.input_indices = []
                        # shapes are lists of per-port shapes, so the single shape is wrapped in a list
                        # (a flat [-1, ...] list here would make output_shape below resolve to -1)
                        const_shape = [-1, *layer.constant_input.shape]
                        const_layer.input_shapes = [const_shape.copy()]
                        const_layer.output_shapes = [const_shape.copy()]
                        const_layer.move_params(layer)
                        self._graph.add_node(const_layer)
                        new_layers.append(const_layer)

                equal_layer.name = f"{equal_layer.op.value}_{layer.name}"
                equal_layer.index = curr_index
                curr_index += 1
                self._graph.add_node(equal_layer)
                new_layers.append(equal_layer)

                # the delta activation consumes the difference produced by equal_layer
                activation_layer = ActivationLayer.from_layer(layer)
                activation_layer.index = curr_index
                curr_index += 1
                activation_layer.name = f"activation_{layer.name}"
                activation_layer.activation = ActivationType.delta
                self._graph.add_node(activation_layer)
                new_layers.append(activation_layer)

                equal_layer.outputs = [activation_layer.name]
                equal_layer.output_indices = [activation_layer.index]
                equal_layer.output_shapes = [layer.output_shape.copy()]
                activation_layer.inputs = [equal_layer.name]
                activation_layer.input_indices = [equal_layer.index]
                activation_layer.input_shapes = [equal_layer.output_shape]
                self._graph.add_edge(equal_layer, activation_layer)

                # redirect the original predecessors to the new subtraction layer
                for pred in preds:
                    self._graph.remove_edge(pred, layer)
                    self._graph.add_edge(pred, equal_layer)
                    pred.replace_output_layer(layer.name, equal_layer.name)
                    pred.replace_output_index(layer.index, equal_layer.index)

                if const_layer is not None:
                    # the const layer becomes an extra input of the ew_sub
                    const_layer.outputs = [equal_layer.name]
                    const_layer.output_indices = [equal_layer.index]
                    self._graph.add_edge(const_layer, equal_layer)
                    equal_layer.append_input_layer(const_layer.name)
                    equal_layer.append_input_index(const_layer.index)
                    equal_layer.append_input_shape(const_layer.output_shape)

                # redirect the original successors to the activation layer
                succs = list(self._graph.successors(layer))
                for succ in succs:
                    self._graph.remove_edge(layer, succ)
                    self._graph.add_edge(activation_layer, succ)
                    succ.replace_input_layer(layer.name, activation_layer.name)
                    succ.replace_input_index(layer.index, activation_layer.index)
                    HailoNN.update_successors_meta_data(succ, successors_meta_data)

                layers_to_remove.append(layer)

        if layers_to_remove:
            for layer in layers_to_remove:
                self._graph.remove_layer(layer)

        for layer in new_layers:
            self._graph.relax_new_layer_into_graph(layer, successors_meta_data)

    def _split_spatial_split_layers(self):
        """
        Swap every spatial splitter for one slice layer per entry of its ``split_sizes``.

        Each slice keeps the splitter's inputs, takes over the i-th output of the splitter and is
        connected to the i-th graph successor. The slice covers the full height / width / features
        range except on the split axis (1 = height, 2 = width), where it covers the running
        ``[offset, offset + split_size]`` window. Splitters are removed and the new slices relaxed
        into the graph after the iteration.
        """
        created_layers = []
        obsolete_layers = []
        successors_meta_data = {}

        for splitter in list(self._graph):
            if splitter.op != LayerType.spatial_splitter:
                continue

            source = next(iter(self._graph.predecessors(splitter)))
            consumers = list(self._graph.successors(splitter))

            first_index = self._graph.get_next_index()
            offset = 0
            slices = []

            for pos, size in enumerate(splitter.split_sizes):
                piece = SliceLayer()
                piece.index = first_index + pos
                piece.name = f"{splitter.name}_slice{pos}"
                slices.append(piece)

                # same inputs as the splitter, but only the pos-th output
                piece.inputs = splitter.inputs.copy()
                piece.input_indices = splitter.input_indices.copy()
                piece.input_shapes = splitter.input_shapes.copy()

                piece.outputs = [splitter.outputs[pos]]
                piece.output_indices = [splitter.output_indices[pos]]
                piece.output_shapes = [splitter.output_shapes[pos]]

                # full range on every axis, narrowed to the current window on the split axis
                rows = [0, splitter.input_height]
                cols = [0, splitter.input_width]
                channels = [0, splitter.input_features]
                window_end = offset + size
                window = [offset, window_end]
                if splitter.axis == 1:
                    rows = window
                if splitter.axis == 2:
                    cols = window
                piece.set_slices_dims(rows, cols, channels)
                offset = window_end

                for original in splitter.original_names:
                    piece.add_original_name(original)

                self._graph.add_node(piece)
                self._graph.add_edge(source, piece)

                # hand the pos-th successor over from the splitter to the new slice
                consumer = consumers[pos]
                self._graph.add_edge(piece, consumer)
                self._graph.remove_edge(splitter, consumer)
                consumer.replace_input_layer(splitter.name, piece.name)
                consumer.replace_input_index(splitter.index, piece.index)
                HailoNN.update_successors_meta_data(consumer, successors_meta_data)

            source.outputs = [new_slice.name for new_slice in slices]
            source.output_indices = [new_slice.index for new_slice in slices]
            self._graph.remove_edge(source, splitter)

            created_layers.extend(slices)
            obsolete_layers.append(splitter)

            self._logger.debug(f"Replaced {splitter.op.value} layer {splitter.name} with equivalent slices")

        if obsolete_layers:
            for obsolete in obsolete_layers:
                self._graph.remove_layer(obsolete)

        for created in created_layers:
            self._graph.relax_new_layer_into_graph(created, successors_meta_data)

    def _collapse_slice_chains(self):
        """
        Merge runs of consecutive slice layers into the first slice of each run.

        A run starts at a ``base_slice`` layer and continues while the current tail has exactly one
        successor and that successor is also a ``base_slice``. The start/end values of the followers
        are accumulated into the first slice, the followers are removed, and the successors of the
        last slice are connected to the first one.

        Raises:
            BackendFuserException: if the accumulated slice has a negative start, an end beyond the
                first slice's input dimensions, or an empty range on any axis.
        """
        layers_to_remove = []
        for head in list(self._graph):
            if head in layers_to_remove or head.op != LayerType.base_slice:
                continue

            # follow the chain of consecutive slices starting at this layer
            chain = [head]
            tail = head
            while True:
                following = list(self._graph.successors(tail))
                if len(following) != 1 or following[0].op != LayerType.base_slice:
                    break
                tail = following[0]
                chain.append(tail)

            if len(chain) <= 1:
                continue

            # found a chain of slices - collapse them into one slice layer
            first = chain[0]
            height_slices = first.height_slice[:2]
            width_slices = first.width_slice[:2]
            features_slices = first.features_slice[:2]

            # add every follower's start offset and subtract what it trims from the end
            for follower in chain[1:]:
                height_slices[0] += follower.height_slice[0]
                height_slices[1] -= follower.input_height - follower.height_slice[1]
                width_slices[0] += follower.width_slice[0]
                width_slices[1] -= follower.input_width - follower.width_slice[1]
                features_slices[0] += follower.features_slice[0]
                features_slices[1] -= follower.input_features - follower.features_slice[1]
                layers_to_remove.append(follower)

            # validate new slice parameters
            all_bounds = (height_slices, width_slices, features_slices)
            starts_negative = any(bounds[0] < 0 for bounds in all_bounds)
            ends_overflow = (
                height_slices[1] > head.input_height
                or width_slices[1] > head.input_width
                or features_slices[1] > head.input_features
            )
            empty_range = any(bounds[0] == bounds[1] for bounds in all_bounds)

            if starts_negative or ends_overflow or empty_range:
                raise BackendFuserException(
                    f"Unexpected Slice values in fused chain of slices near {head.name}: "
                    f"The totals of the slice starts/ends have resulted in unsupported "
                    f"values for slicing. height_slice={height_slices[:2]}, "
                    f"width_slice={width_slices[:2]}, features_slice={features_slices[:2]}."
                    f" (translated from {head.original_names}).",
                )

            # in case of two consecutive slices, one on w and the other on h, with steps != 1,
            # the steps should be changed
            chain_height_steps = [member.height_slice[-1] for member in chain]
            chain_width_steps = [member.width_slice[-1] for member in chain]
            steps_split_across_axes = (
                len(chain) == 2
                and 1 in chain_height_steps  # one of steps is 1
                and np.prod(chain_height_steps) != 1  # one of steps is not 1
                and 1 in chain_width_steps
                and np.prod(chain_width_steps) != 1
            )
            if steps_split_across_axes:
                width_steps = max(chain_width_steps)
                height_steps = max(chain_height_steps)
            else:
                width_steps = chain_width_steps[0]
                height_steps = chain_height_steps[0]

            # set slice layer with finalized parameters and shapes
            head.height_slice = [*height_slices, height_steps]
            head.width_slice = [*width_slices, width_steps]
            head.features_slice = [*features_slices, head.features_slice[-1]]

            last_layer = chain[-1]
            for original in last_layer.original_names:
                head.add_original_name(original)

            # bypass all middle slices between first and last
            for consumer in list(self._graph.successors(last_layer)):
                self._graph.add_edge(head, consumer)
                self._graph.remove_edge(last_layer, consumer)
                consumer.append_input_layer(head.name)
                consumer.append_input_index(head.index)
                consumer.append_input_shapes(head.output_shapes)
                head.append_output_layer(consumer.name)
                head.append_output_index(consumer.index)

        # finalize graph manipulation outside graph iteration
        if layers_to_remove:
            for obsolete in layers_to_remove:
                self._graph.remove_layer(obsolete)
            self._graph.calculate_shapes()

    def _handle_strided_slices(self):
        """
        Handle strided slices in the graph.

        This method only delegates to ``_convert_strided_slices_to_row_splitters``; it performs no
        verification of its own, so any check that all strided slices were handled has to happen
        inside that call.
        """
        self._convert_strided_slices_to_row_splitters()

    def _convert_strided_slices_to_row_splitters(self):
        """
        Find row splitter implemented by strided slices and replace with row split layers.
        """
        new_layers = []
        # layers_to_remove is a dict in order to match _convert_slices_to_feature_splitter which
        # also calls _create_splitter (here there isn't a use of the dict values but only the keys)
        layers_to_remove = {}
        successors_meta_data = {}
        num_slices_in_graph = 0

        for layer in list(self._graph):
            if layer in layers_to_remove:
                continue

            # find successors that are strided slices
            stride = 0
            num_slices = 0
            is_row_splitter = True
            succs = list(self._graph.successors(layer))
            for succ in succs:
                if succ.op == LayerType.base_slice:
                    num_slices += 1
                    if stride not in {0, succ.height_slice[2]} or len(succ.input_shape) != 4:
                        is_row_splitter = False
                    stride = succ.height_slice[2]

            if num_slices == 0 or not is_row_splitter:
                num_slices_in_graph += num_slices
                continue

            # check that slicing is only on height and that row distribution is correct
            strided_slices = []
            start_indices_covered = [False] * num_slices
            for succ in succs:
                if succ.op == LayerType.base_slice:
                    if succ.height_slice[0] >= num_slices:
                        is_row_splitter = False
                    else:
                        start_indices_covered[succ.height_slice[0]] = True

                    height_cond = succ.output_height >= floor((succ.input_height - succ.height_slice[0]) / stride)
                    width_cond = succ.output_width == layer.output_width
                    features_cond = succ.output_shapes[0][-1] == layer.output_shapes[0][-1]
                    is_row_splitter = is_row_splitter and height_cond and width_cond and features_cond
                    strided_slices.append(succ)

            # sort the slices by slicing start index, so that output_indices are sorted correctly
            strided_slices = sorted(strided_slices, key=lambda x: x.height_slice[0])

            # check that all rows are covered
            if len(strided_slices) > 1 and any(x is False for x in start_indices_covered):
                is_row_splitter = False

            if len(strided_slices) > 1 and is_row_splitter:
                self._create_splitter(
                    layer,
                    LayerType.row_splitter,
                    strided_slices,
                    layers_to_remove,
                    new_layers,
                    successors_meta_data,
                )
            else:
                num_slices_in_graph += len(strided_slices)

        # finalize graph manipulation outside graph iteration
        if layers_to_remove:
            for layer in layers_to_remove:
                self._graph.remove_layer(layer)

        for layer in new_layers:
            self._graph.relax_new_layer_into_graph(layer, successors_meta_data)

    def _create_splitter(
        self,
        input_layer,
        splitter_type,
        strided_slices,
        layers_to_remove,
        new_layers,
        successors_meta_data,
    ):
        """
        Replace a group of sibling slice layers with a single splitter layer fed by ``input_layer``.

        Args:
            input_layer: the common predecessor of all the slices.
            splitter_type: ``LayerType.feature_splitter``, ``LayerType.row_splitter`` or
                ``LayerType.width_splitter``. The row and width variants are accepted only when the first
                slice has ``groups == 1``.
            strided_slices: the slice layers to replace, already sorted in the desired output order.
            layers_to_remove: dict updated in place; maps every replaced slice to its successor when that
                successor is a shortcut layer, otherwise to the new splitter.
            new_layers: list updated in place with the splitter and with any shortcut created for a slice.
            successors_meta_data: dict updated in place, keyed by successor name, holding the ``inputs``,
                ``input_indices`` and ``input_shapes`` the successor should have once the slices are gone.

        Raises:
            BackendFuserException: if the splitter type (combined with the groups value) is not supported.
        """
        first_slice = strided_slices[0]
        groups = first_slice.groups
        is_feature_splitter = splitter_type == LayerType.feature_splitter
        if is_feature_splitter:
            name_prefix = "feature_splitter_"
            splitter = FeatureSplitterLayer()
            splitter.groups = groups
        elif splitter_type == LayerType.row_splitter and groups == 1:
            name_prefix = "row_splitter_"
            splitter = RowSplitterLayer()
        elif splitter_type == LayerType.width_splitter and groups == 1:
            name_prefix = "width_splitter_"
            splitter = WidthSplitterLayer()
        else:
            raise BackendFuserException(f"Invalid splitter layer type {splitter_type.value}")

        # the splitter takes the place of the first slice and inherits the original names of all slices
        splitter.index = self._graph.get_next_index()
        splitter.name = name_prefix + first_slice.name
        splitter.inputs = first_slice.inputs
        splitter.input_shapes = first_slice.input_shapes
        splitter.original_names = [name for slice_layer in strided_slices for name in slice_layer.original_names]

        self._graph.add_node(splitter)
        self._graph.add_edge(input_layer, splitter)
        new_layers.append(splitter)

        for position, slice_layer in enumerate(strided_slices):
            successor = self._get_slice_successor(slice_layer, new_layers, splitter.index + position + 1)
            replacement = successor if successor.op == LayerType.shortcut else splitter
            layers_to_remove[slice_layer] = replacement
            splitter.append_output_shape(slice_layer.output_shape)
            splitter.append_output_layer(successor.name)

            # record what the successor's inputs should look like once the slice is deleted; the new layers
            # use this information later to update their own location and parameters in the model.
            if successor.name in successors_meta_data:
                meta = successors_meta_data[successor.name]
            else:
                meta = {
                    "inputs": list(successor.inputs),
                    "input_indices": list(successor.input_indices),
                    "input_shapes": list(successor.input_shapes),
                }
            slot = meta["inputs"].index(slice_layer.name)
            meta["inputs"][slot] = splitter.name
            meta["input_indices"][slot] = splitter.index
            meta["input_shapes"][slot] = slice_layer.output_shape
            successors_meta_data[successor.name] = meta

            if is_feature_splitter:
                # slices that start at the same feature share a split index, any other slice gets a new one;
                # slices are promised to be ordered by the start of the features slice
                previous_slice = strided_slices[position - 1]
                opens_new_split = (
                    len(splitter.split_indices) == 0
                    or slice_layer.features_slice[0] != previous_slice.features_slice[0]
                )
                if opens_new_split:
                    splitter.split_indices.append(len(splitter.split_indices))
                else:
                    splitter.split_indices.append(splitter.split_indices[-1])

            self._graph.add_edge(splitter, successor)
            self._graph.remove_edge(input_layer, slice_layer)
            self._graph.remove_edge(slice_layer, successor)

        self._logger.debug(f"Fused strided slices to a new splitter layer after {input_layer.name}.")

    def _get_slice_successor(self, slice_layer, new_layers, index):
        succs = list(self._graph.successors(slice_layer))
        if len(succs) == 1:
            return succs[0]

        # adds shortcut in case of multiple outputs of slice layer
        shortcut = ShortcutLayer()
        shortcut.index = index
        shortcut.name = f"shortcut_{slice_layer.name}"
        shortcut.inputs = [slice_layer.name]
        shortcut.input_indices = [slice_layer.index]
        shortcut.input_shapes = [slice_layer.output_shape]

        self._graph.add_node(shortcut)
        self._graph.add_edge(slice_layer, shortcut)

        for succ in succs:
            succ.replace_input_layer(slice_layer.name, shortcut.name)
            succ.replace_input_index(slice_layer.index, shortcut.index)
            self._graph.remove_edge(slice_layer, succ)
            self._graph.add_edge(shortcut, succ)
            shortcut.append_output_layer(succ.name)
            shortcut.append_output_index(succ.index)
            index = succ.inputs.index(shortcut.name)
            shortcut.append_output_shape(succ.input_shapes[index])

        new_layers.append(shortcut)
        self._logger.debug(f"Added shortcut layer after slice layer {slice_layer.name}.")
        return shortcut

    def _can_convert_succs_to_feature_splitter(self, layer, succs):
        if len(succs) < 2 or layer.op in [
            LayerType.feature_splitter,
            LayerType.spatial_splitter,
            LayerType.width_splitter,
        ]:
            return False

        all_slices_succs = []
        features_covered_by_slices = [False] * layer.output_shapes[0][-1]
        groups = None
        for succ in succs:
            if groups and succ.groups != groups:
                return False
            groups = succ.groups
            features_per_group = layer.output_shapes[0][-1] // groups
            # check that slicing is only on features
            if len(succ.input_shape) == 4 and layer.output_shape[1:3] != succ.output_shape[1:3]:
                return False

            for group_index in range(groups):
                for i in range(succ.features_slice[0], succ.features_slice[1]):
                    curr_index = group_index * features_per_group + i
                    # if the feature is already covered, then the slices features are overlapping
                    if features_covered_by_slices[curr_index]:
                        return False
                    features_covered_by_slices[curr_index] = True
            all_slices_succs.extend(list(self._graph.successors(succ)))

        # check that none of the slices have the same successor
        if len(set(all_slices_succs)) != len(all_slices_succs):
            return False

        if any(succ.op == LayerType.output_layer for succ in all_slices_succs):
            return False

        # check that all features are covered
        return all(x is True for x in features_covered_by_slices)

    def _convert_slices_to_feature_splitter(self):
        """Find feature splitter implemented by slices and replace with feature split layers."""
        new_layers = []
        layers_to_remove = {}  # mapping of the removed layer to its replacement
        successors_meta_data = {}

        for layer in list(self._graph):
            if layer in layers_to_remove:
                layer = layers_to_remove[layer]

            succs = [x for x in self._graph.successors(layer) if x.op == LayerType.base_slice]
            if not self._can_convert_succs_to_feature_splitter(layer, succs):
                continue

            # sort the slices by slicing start index, so that output_indices are sorted correctly
            strided_slices = sorted(succs, key=lambda x: x.features_slice[0])

            self._create_splitter(
                layer,
                LayerType.feature_splitter,
                strided_slices,
                layers_to_remove,
                new_layers,
                successors_meta_data,
            )

        # finalize graph manipulation outside graph iteration
        if layers_to_remove:
            for layer in layers_to_remove:
                self._graph.remove_layer(layer)

        for layer in new_layers:
            self._graph.relax_new_layer_into_graph(layer, successors_meta_data)

    def _split_l2_normalization_layers(self):
        """
        Replace every l2 normalization layer in the graph.

        * When ``reduce_axes`` differs from ``[3]``, the layer becomes a reduce l2 layer (which keeps the
          original layer name) followed by an element-wise division of the original input by the reduce result.
        * When ``reduce_axes`` is exactly ``[3]``, the layer becomes a layer normalization layer built with
          ``LayerNormalizationLayer.from_layer``.

        The original layers are removed and the new ones are relaxed into the graph after the scan.
        """
        successors_meta_data = {}
        layers_to_remove = []
        new_layers = []
        for layer in list(self._graph):
            if layer.op != LayerType.l2_normalization:
                continue

            if layer.reduce_axes != [3]:
                pred = next(iter(self._graph.predecessors(layer)))
                succs = list(self._graph.successors(layer))
                block_info = (BlockType.L2_NORMALIZATION, layer.name)

                l2_reducer = ReduceL2Layer()
                l2_reducer.name = layer.name
                l2_reducer.index = self._graph.get_next_index()
                l2_reducer.input_shapes = layer.input_shapes
                l2_reducer.reduce_axes = layer.reduce_axes
                # the reduced axes collapse to size 1, every other dimension is kept
                l2_reducer.output_shapes = [
                    [1 if axis in layer.reduce_axes else dim for axis, dim in enumerate(layer.output_shapes[0])]
                ]
                l2_reducer.block_info = block_info

                divider = EWDivLayer()
                divider.name = f"ew_div_{layer.name}"
                divider.index = l2_reducer.index + 1
                divider.input_shapes = pred.output_shapes[pred.outputs.index(layer.name)]
                divider.append_input_shapes(l2_reducer.output_shapes)
                divider.output_shapes = l2_reducer.input_shapes
                divider.block_info = block_info
                # only the divider, which is the last layer of the pair, carries the original names
                for orig_name in layer.original_names:
                    divider.add_original_name(orig_name)

                l2_reducer.append_input_layer(pred.name)
                l2_reducer.append_input_index(pred.index)
                l2_reducer.append_output_layer(divider.name)
                l2_reducer.append_output_index(divider.index)

                # the divider reads the original input first and the reduce result second
                divider.append_input_layer(pred.name)
                divider.append_input_layer(l2_reducer.name)
                divider.append_input_index(pred.index)
                divider.append_input_index(l2_reducer.index)

                self._graph.add_node(l2_reducer)
                self._graph.add_node(divider)

                for succ in succs:
                    self._graph.remove_edge(layer, succ)
                    self._graph.add_edge(divider, succ)
                    succ.replace_input_layer(layer.name, divider.name)
                    succ.replace_input_index(layer.index, divider.index)
                    HailoNN.update_successors_meta_data(succ, successors_meta_data)

                self._graph.add_edge(pred, l2_reducer)
                self._graph.add_edge(pred, divider)
                self._graph.add_edge(l2_reducer, divider)

                new_layers.extend([l2_reducer, divider])
                layers_to_remove.append(layer)
                self._logger.debug(
                    f"Replaced L2 normalization layer {layer.name} with ew-divider and L2 reducer layers.",
                )
                continue

            # reduce_axes == [3]: a one-to-one swap with a layer normalization layer
            layer_norm = LayerNormalizationLayer.from_layer(layer)
            new_layers.append(layer_norm)

            succs = list(self._graph.successors(layer))
            for succ in succs:
                self._fuser_helper.replace_pred(succ, layer, layer_norm)

            preds = list(self._graph.predecessors(layer))
            for pred in preds:
                self._fuser_helper.replace_succ(pred, layer, layer_norm)

            self._fuser_helper.add_preds(layer_norm, preds)
            self._fuser_helper.add_succs(layer_norm, succs)
            layers_to_remove.append(layer)

            self._logger.debug(
                f"Replaced L2 normalization layer {layer.name} with layer normalization layer.",
            )

        # finalize graph manipulation outside graph iteration
        for removed_layer in layers_to_remove:
            self._graph.remove_layer(removed_layer)

        for new_layer in new_layers:
            self._graph.relax_new_layer_into_graph(new_layer, successors_meta_data)

    def _handle_reduce_sum_mean_layers(self):
        self._modify_reduce_sum_mean_by_axes()
        self._split_reduce_l2_and_sum_square()

    def _modify_width_reduce_sum(self, layer, pred, succs, new_layers):
        """
        Replace a reduce sum over axis 2 with a reduce sum over axes ``[1, 2]`` that uses height groups.

        Args:
            layer: the original reduce sum layer (it is not removed here).
            pred: the layer feeding ``layer``.
            succs: the layers consuming the output of ``layer``.
            new_layers: list handed to the fuser helper when the replacement layer is created.
        """
        # "Hacky" support of reduce sum on W axis as reduce sum on H+W axes with h_groups=h
        reduced_shape = layer.input_shape.copy()
        reduced_shape[2] = 1
        replacement = self._fuser_helper.create_layer(
            ReduceSumLayer,
            self._graph.get_next_index(),
            "spatial",
            layer,
            new_layers,
            [reduced_shape],
        )
        replacement.reduce_axes = [1, 2]
        replacement.height_groups = reduced_shape[1]

        # the replacement sits between the predecessor and all the original successors
        inputs_by_layer = {replacement: [pred], **{succ: [replacement] for succ in succs}}
        outputs_by_layer = {pred: [replacement], replacement: succs}
        self._fuser_helper.handle_new_preds_succs(inputs_by_layer, outputs_by_layer)

    def _split_reduce_sum_mean_layer(self, layer, pred, succs, new_layers, reduce_type):
        """
        Split a reduce sum / reduce mean layer into a features reduce followed by a spatial reduce.

        The first new layer reduces axis 3. The second handles the remaining axes: for ``ReduceMeanLayer`` it
        is a pooling layer configured through ``_set_avgpool_attrs``, otherwise it is another ``reduce_type``
        layer whose axes are the original (sorted) axes without the last one.

        Args:
            layer: the original reduce layer (it is not removed here).
            pred: the layer feeding ``layer``.
            succs: the layers consuming the output of ``layer``.
            new_layers: list handed to the fuser helper when the new layers are created.
            reduce_type: the layer class of the original reduce (``ReduceSumLayer`` or ``ReduceMeanLayer``).
        """
        reduced_axes = sorted(layer.reduce_axes)
        features_out_shape = [*layer.input_shape[:3], 1]
        first_index = self._graph.get_next_index()
        features_reduce = self._fuser_helper.create_layer(
            reduce_type,
            first_index,
            "features",
            layer,
            new_layers,
            [features_out_shape],
        )
        features_reduce.reduce_axes = [3]

        spatial_out_shapes = [[1 if axis in reduced_axes else dim for axis, dim in enumerate(features_out_shape)]]
        is_mean = reduce_type == ReduceMeanLayer
        spatial_reduce = self._fuser_helper.create_layer(
            PoolingLayer if is_mean else reduce_type,
            first_index + 1,
            "spatial",
            layer,
            new_layers,
            spatial_out_shapes,
        )
        if is_mean:
            # the mean over the spatial axes is an average pool whose kernel spans each reduced axis
            kernel_h = features_out_shape[1] if 1 in reduced_axes else 1
            kernel_w = features_out_shape[2] if 2 in reduced_axes else 1
            self._set_avgpool_attrs(spatial_reduce, kernel_h, kernel_w)
        else:
            spatial_reduce.reduce_axes = reduced_axes[:-1]

        inputs_by_layer = {
            features_reduce: [pred],
            spatial_reduce: [features_reduce],
            **{succ: [spatial_reduce] for succ in succs},
        }
        outputs_by_layer = {pred: [features_reduce], features_reduce: [spatial_reduce], spatial_reduce: succs}
        self._fuser_helper.handle_new_preds_succs(inputs_by_layer, outputs_by_layer)

    def _modify_reduce_sum_mean_by_axes(self):
        """
        Rewrite reduce sum / reduce mean layers whose sorted axes are ``[1, 2, 3]``, ``[1, 3]`` or ``[2]``.

        Axes ``[2]`` are handled by ``_modify_width_reduce_sum``; the other two cases are handled by
        ``_split_reduce_sum_mean_layer``. The original layer is disconnected from its neighbours and removed
        from the graph once the scan is over.
        """
        layers_to_remove = []
        new_layers = []
        reduce_op_to_type = {LayerType.reduce_sum: ReduceSumLayer, LayerType.reduce_mean: ReduceMeanLayer}
        handled_axes = [[1, 2, 3], [1, 3], [2]]
        for layer in list(self._graph):
            if layer.op not in reduce_op_to_type or sorted(layer.reduce_axes) not in handled_axes:
                continue

            pred = next(iter(self._graph.predecessors(layer)))
            succs = list(self._graph.successors(layer))
            if layer.reduce_axes == [2]:
                # can only be reduce sum, because reduce mean with axes 2 is avgpool
                self._modify_width_reduce_sum(layer, pred, succs, new_layers)
            else:
                reduce_type = reduce_op_to_type[layer.op]
                self._split_reduce_sum_mean_layer(layer, pred, succs, new_layers, reduce_type)

            # detach the original layer from both sides before it is dropped
            for succ in succs:
                self._graph.remove_edge(layer, succ)
            self._graph.remove_edge(pred, layer)
            layers_to_remove.append(layer)
            self._logger.debug(
                f"Replaced {layer.op.value} layer {layer.name} on all axes with two reduce sum "
                "layers (features and spatial).",
            )

        for removed_layer in layers_to_remove:
            self._graph.remove_layer(removed_layer)

    def _handle_interleaved_groups_reduce_layer(self):
        """
        Handle reduce sum layers that are flagged with ``interleaved_groups``.

        Such a layer is supported only when its predecessor is a concat layer without ``group_sizes``; in that
        case the concat gets one group of size 1 per reduce group.

        Raises:
            BackendFuserException: if the predecessor is not a concat layer, or already has ``group_sizes``.
        """
        for layer in list(self._graph):
            if layer.op != LayerType.reduce_sum or not layer.interleaved_groups:
                continue

            pred = next(iter(self._graph.predecessors(layer)))
            if pred.op != LayerType.concat or pred.group_sizes is not None:
                raise BackendFuserException(
                    f"{layer.full_name_msg} is an interleaved groups reduce layer which is "
                    f"supported only after concat",
                )

            pred.group_sizes = [1] * layer.groups
            # NOTE(review): the check above reads `interleaved_groups` while this line writes
            # `groups_interleaved` - confirm which attribute name the reduce layer actually exposes.
            layer.groups_interleaved = False

    def _split_reduce_l2_and_sum_square(self):
        """
        Find reduce l2 and reduce sum square layers and replace each with feature multiplier and reduce sum.

        The feature multiplier squares the input and the reduce sum adds the squares up; for reduce l2 the
        reduce sum gets a sqrt activation, for reduce sum square a linear one.
        """
        successors_meta_data = {}
        layers_to_remove = []
        new_layers = []
        for layer in list(self._graph):
            if layer.op not in (LayerType.reduce_l2, LayerType.reduce_sum_square):
                continue

            pred = next(iter(self._graph.predecessors(layer)))
            succs = list(self._graph.successors(layer))

            summer = self._fuser_helper.create_layer(
                ReduceSumLayer,
                self._graph.get_next_index(),
                "reduce_sum",
                layer,
                new_layers,
                layer.output_shapes,
            )
            if layer.op == LayerType.reduce_l2:
                summer.activation = ActivationType.sqrt
            else:
                summer.activation = ActivationType.linear

            squarer = self._fuser_helper.create_layer(
                FeatureMultiplierLayer,
                summer.index + 1,
                "feature_multiplier",
                layer,
                new_layers,
                [layer.input_shape],
            )
            squarer.original_names = []  # assigning original names only for the last generated layer
            squarer.feature_multiplier_type = FeatureMultiplierType.square

            # predecessor -> feature multiplier -> reduce sum -> original successors
            pred.replace_output_layer(layer.name, squarer.name)
            pred.replace_output_index(layer.index, squarer.index)

            summer.append_input_layer(squarer.name)
            summer.append_input_index(squarer.index)
            summer.input_shapes = squarer.output_shapes

            for succ in succs:
                self._graph.remove_edge(layer, succ)
                self._graph.add_edge(summer, succ)
                succ.replace_input_layer(layer.name, summer.name)
                succ.replace_input_index(layer.index, summer.index)
                HailoNN.update_successors_meta_data(succ, successors_meta_data)

            self._graph.remove_edge(pred, layer)
            self._graph.add_edge(pred, squarer)
            self._graph.add_edge(squarer, summer)

            new_layers.extend([squarer, summer])
            layers_to_remove.append(layer)
            self._logger.debug(
                f"Replaced {layer.op.value} layer {layer.name} with feature multiplier and reduce sum layers.",
            )

        # finalize graph manipulation outside graph iteration
        for removed_layer in layers_to_remove:
            self._graph.remove_layer(removed_layer)

        for new_layer in new_layers:
            self._graph.relax_new_layer_into_graph(new_layer, successors_meta_data)

    def _split_multi_layers_activations(self):
        """Split activations that can be broken apart to several layers."""
        successors_meta_data = {}
        layers_to_remove = []
        new_layers = []
        for layer in list(self._graph):
            if layer.op == LayerType.base_activation:
                if layer.activation == ActivationType.swish:
                    if layer.swish_beta == 1:
                        layer.activation = ActivationType.silu
                        layer.swish_beta = None

        # finalize graph manipulation outside graph iteration
        if layers_to_remove:
            for layer in layers_to_remove:
                self._graph.remove_layer(layer)

        for layer in new_layers:
            self._graph.relax_new_layer_into_graph(layer, successors_meta_data)

    def _fuse_padding(self):
        """
        Fuse external pad layers to their successors if the padding scheme is supported, or drop null pads.

        * A pad layer whose six pad values are all zero is removed without fusing it to any neighbour.
        * A pad layer with exactly one successor is fused to it when ``_can_fuse_padding_to_succ`` agrees and
          ``_get_padding_scheme`` yields a padding scheme.
        * Otherwise, if that single successor is a feature splitter, the padding is fused to the layers that
          follow the splitter when ``_can_fuse_padding_after_feature_splitter`` agrees.
        """
        layers_to_remove = []

        for pad_layer in list(self._graph):
            if pad_layer.op != LayerType.external_pad:
                continue

            pad_succs = list(self._graph.successors(pad_layer))
            pad_amounts = (
                pad_layer.top,
                pad_layer.bottom,
                pad_layer.left,
                pad_layer.right,
                pad_layer.front,
                pad_layer.back,
            )
            if all(amount == 0 for amount in pad_amounts):
                # Found null padding, should be removed
                self._fuser_helper.remove_layer(
                    pad_layer,
                    layers_to_remove,
                    fuse_to_succ=False,
                    fuse_to_pred=False,
                )
                self._logger.debug(f"Removed null padding layer {pad_layer.name}.")
                continue

            if len(pad_succs) != 1:
                continue

            pad_succ = pad_succs[0]
            if self._can_fuse_padding_to_succ(pad_layer, pad_succ):
                padding, _ = self._get_padding_scheme(pad_layer, pad_succ)
                # Found matching padding type - fusing to the layer
                if padding is not None:
                    self._fuse_padding_to_succ(pad_layer, padding, pad_succ, layers_to_remove)
                    self._logger.debug(f"Fused padding layer {pad_layer.name} to layer {pad_succ.name}.")
            elif pad_succ.op == LayerType.feature_splitter:
                split_succs = list(self._graph.successors(pad_succ))
                if self._can_fuse_padding_after_feature_splitter(pad_layer, split_succs):
                    # Found matching padding type - fusing to the split successors
                    self._fuse_padding_after_feature_splitter(
                        pad_layer,
                        pad_succ,
                        split_succs,
                        layers_to_remove,
                    )
                    self._logger.debug(
                        f"Fused padding layer {pad_layer.name} to layers after feature splitter "
                        f"{pad_succ.name}.",
                    )

        for removed_layer in layers_to_remove:
            self._graph.remove_layer(removed_layer)

    def _apply_avgpool_correction(self):
        """
        Apply padding correction to average pooling layers that need it.

        A layer is handled when it is an avgpool ``PoolingLayer`` with non-valid padding, it is not yet marked
        with ``required_padding_correction`` and it does not count the padding (``count_include_pad``). Such a
        layer is marked, its padding is computed with ``calculate_padding`` and handed to
        ``_apply_padding_correction``.
        """
        for layer in list(self._graph):
            needs_correction = (
                isinstance(layer, PoolingLayer)
                and layer.op == LayerType.avgpool
                and layer.padding != PaddingType.valid
                and not layer.required_padding_correction
                and not layer.count_include_pad
            )
            if not needs_correction:
                continue

            layer.required_padding_correction = True
            # fall back to unit dilations when the layer does not define any
            dilations = layer.dilations or [1, 1, 1, 1]
            end_h, begin_h, end_w, begin_w = calculate_padding(
                layer.padding,
                layer.kernel_height,
                layer.kernel_width,
                layer.stride_height,
                layer.stride_width,
                layer.input_height,
                layer.input_width,
                dilations,
            )
            self._apply_padding_correction(layer, (begin_h, end_h, begin_w, end_w))

    def _handle_pooling_ceil_mode(self):
        """
        Emulate pooling ``ceil_mode`` by placing an external padding layer before the pooling layer.

        For every pooling layer with ``ceil_mode`` set, the padding is computed by
        ``_calculate_ceil_padding_schema``. When ``padding_vals[1]`` or ``padding_vals[3]`` is non-zero, an
        external padding layer is inserted between the pooling layer and its predecessor, ``ceil_mode`` is
        cleared and, for avgpool, padding correction is applied. Otherwise the layer is left untouched.
        """
        for layer in list(self._graph):
            if not isinstance(layer, PoolingLayer) or not layer.ceil_mode:
                continue

            padding_vals = self._calculate_ceil_padding_schema(
                layer.input_shape[1:3],
                layer.kernel_shape[1:3],
                layer.strides[1:3],
            )
            if padding_vals[1] == 0 and padding_vals[3] == 0:
                # external padding is not needed
                continue

            # adds external padding layer to fix the output shape
            pad_layer = ExternalPadLayer()
            pad_layer.name = f"{layer.name}_ceil_mode_padding"
            pad_layer.index = self._graph.get_next_index()
            pad_layer.original_names = layer.original_names.copy()
            pad_layer.set_pad(padding_vals)

            pred = next(iter(self._graph.predecessors(layer)))
            pred.replace_output_index(layer.index, pad_layer.index)
            pred.replace_output_layer(layer.name, pad_layer.name)

            # the padding layer takes over the inputs of the pooling layer
            pad_layer.inputs = layer.inputs.copy()
            pad_layer.input_shapes = layer.input_shapes.copy()
            pad_layer.input_indices = layer.input_indices.copy()

            pad_layer.outputs = [layer.name]
            pad_layer.output_indices = [layer.index]
            # height grows by padding_vals[1] and width by padding_vals[3]
            pad_layer.output_shapes = [
                pad_layer.input_shape[0],
                pad_layer.input_shape[1] + padding_vals[1],
                pad_layer.input_shape[2] + padding_vals[3],
                pad_layer.input_shape[3],
            ]

            layer.inputs = [pad_layer.name]
            layer.input_indices = [pad_layer.index]
            layer.input_shapes = pad_layer.output_shapes

            # identifying the ceil mode is no longer needed because we add external padding to the graph to
            # fix that
            layer.ceil_mode = False

            for succ in list(self._graph.successors(layer)):
                succ.input_shapes[succ.inputs.index(layer.name)] = layer.output_shapes

            self._graph.add_node(pad_layer)
            self._graph.add_edge(pred, pad_layer)
            self._graph.add_edge(pad_layer, layer)
            self._graph.remove_edge(pred, layer)

            if layer.op == LayerType.avgpool:
                # applying avgpool padding correction, values ordered as (bottom, top, right, left)
                # NOTE(review): _apply_avgpool_correction passes (begin_h, end_h, begin_w, end_w) to the same
                # helper - confirm this different ordering is what has_external_pad=True expects.
                correction = [padding_vals[1], padding_vals[0], padding_vals[3], padding_vals[2]]
                self._apply_padding_correction(layer, correction, has_external_pad=True)

    def _handle_disparity_resize(self):
        """
        Handle resize layers that scale the disparity axis by an integer factor.

        A resize layer is handled when it holds exactly one disparity ratio and that ratio is not 1.0:

        * if its direct successor is a matmul whose ``groups`` equals ``input_disparity * ratio``, the ratio is
          written into the matmul's ``input_tiles`` entry for this input and the resize layer is removed;
        * otherwise a grouped conv1x1 (linear activation, valid padding, tiled identity kernel) is pushed into the
          graph through ``push_layer(conv, [layer])`` to simulate the resize operation.

        Raises:
            BackendFuserException: if the disparity ratio is not a whole number.
        """
        new_layers = []
        successors_meta_data = {}
        layers_to_remove = []
        for layer in list(self._graph):
            if layer.op != LayerType.resize or len(layer.d_ratios) != 1 or layer.d_ratios == [1.0]:
                continue

            d_ratio = layer.d_ratios[0]
            if not d_ratio.is_integer():
                # the trailing space keeps the two sentences apart once the literals are concatenated
                raise BackendFuserException(
                    f"Layer {layer.name} has a non-integer disparity factor {d_ratio}. "
                    "Performing resize on the disparity axis by a non-integer factor is not supported",
                )
            matmul_succ = look_for_node(
                self._graph,
                layer,
                [FwdChainNode(op=LayerType.matmul)],
                exact_match=True,
            )

            if matmul_succ and matmul_succ.groups == layer.input_disparity * d_ratio:
                # instead of resizing the disparity, we can just update the matmul tiles attribute.
                # then the resizing is done in the kernel of the matmul
                matmul_succ.input_tiles[matmul_succ.inputs.index(layer.name)][-1] = int(d_ratio)
                self._fuser_helper.remove_layer(layer, layers_to_remove=layers_to_remove)
                continue

            groups, strides, dilations = layer.input_disparity, [1, 1, 1, 1], [1, 1, 1, 1]
            conv1x1_layer = Conv2DLayer()
            conv1x1_layer.original_names = layer.original_names
            conv1x1_layer.index = self._graph.get_next_index()
            conv1x1_layer.name = f"disparity_resize{layer.index}"
            # input_features = real_features * input_disparity. dividing to extract the actual features
            kernel = np.concatenate(
                [np.identity(layer.input_features // layer.input_disparity) for _ in range(int(d_ratio))],
                axis=1,
            )
            # repeat the per-group kernel along the output axis, once per group
            kernel = np.tile(kernel, groups)
            reshaped_kernel = np.reshape(kernel, [1, 1, kernel.shape[0], kernel.shape[1]])

            conv1x1_layer.kernel = reshaped_kernel
            # kernel shape is [k_h, k_w, f_in * groups, f_out]
            conv1x1_layer.kernel_shape = [
                dim if i != 2 else dim * groups for i, dim in enumerate(reshaped_kernel.shape)
            ]
            conv1x1_layer.strides = strides
            conv1x1_layer.dilations = dilations
            conv1x1_layer.padding = PaddingType.valid
            conv1x1_layer.groups = groups
            conv1x1_layer.activation = ActivationType.linear

            # the new conv takes over the resize layer's first output
            old_output_index = layer.output_indices[0]
            self._graph.push_layer(conv1x1_layer, [layer])
            layer.replace_output_index(old_output_index, conv1x1_layer.index)
            conv1x1_layer.append_output_index(old_output_index)
            HailoNN.update_successors_meta_data(layer, successors_meta_data)
            new_layers.append(conv1x1_layer)

        # finalize graph manipulation outside graph iteration
        for layer in new_layers:
            self._graph.relax_new_layer_into_graph(layer, successors_meta_data)
        for layer_to_remove in layers_to_remove:
            self._graph.remove_layer(layer_to_remove)

    def _calculate_ceil_padding_schema(self, input_shape, kernel_shape, strides):
        # this function extracts the start index of the last pooling sliding window
        # to calculate how many paddings rows/cols are required to complete the window
        padding_val = []

        for dim in range(2):
            padding_val.append(0)
            start_index = 0
            while start_index + kernel_shape[dim] < input_shape[dim]:
                start_index += strides[dim]

            padding_val.append(start_index + kernel_shape[dim] - input_shape[dim])

        return padding_val

    def _can_fuse_padding_to_succ(self, pad_layer, succ):
        """
        Tell whether an external padding layer may be folded into ``succ`` as internal padding.

        All of the following must hold:

        * ``succ`` is a conv, deconv, depthwise, maxpool or avgpool layer (the only ops with internal padding);
        * ``succ`` does not already pad, i.e. its padding is ``PaddingType.valid``;
        * ``pad_layer`` leaves the features axis untouched (``front`` and ``back`` are both 0).
        """
        ops_with_internal_padding = (
            LayerType.base_conv,
            LayerType.base_deconv,
            LayerType.base_dw,
            LayerType.maxpool,
            LayerType.avgpool,
        )
        # checks are evaluated left to right and stop at the first one that fails
        fusable = (
            succ.op in ops_with_internal_padding
            and succ.padding == PaddingType.valid
            and pad_layer.front == 0
            and pad_layer.back == 0
        )
        return bool(fusable)

    def _can_fuse_padding_after_feature_splitter(self, pad_layer, split_succs):
        # Validate we can fuse padding to all split successors
        for split_succ in split_succs:
            if not self._can_fuse_padding_to_succ(pad_layer, split_succ):
                return False

            # The external padding don't match known type of padding
            padding_scheme, _ = self._get_padding_scheme(pad_layer, split_succ)
            if padding_scheme is None:
                return False

        return True

    @staticmethod
    def _get_padding_scheme(pad_layer, succ):
        """
        Match an external padding layer against the padding schemes ``succ`` can apply internally.

        The total padding per spatial axis is derived from the (dilated) kernel size, the stride and the input
        size of ``pad_layer``; the smaller half is the "begin" amount and the remainder is the "end" amount.
        ``same_tensorflow`` puts the begin amount on the top/left side, ``same`` puts the end amount there.
        The bottom/right side either has to match exactly, or has to produce the same output height and width
        as the computed scheme.

        Args:
            pad_layer: external padding layer; its ``top``/``bottom``/``left``/``right`` amounts and input
                height/width are read.
            succ: layer consuming the padded tensor; kernel size, strides and (for non-pooling layers)
                dilations are read.

        Returns:
            tuple: ``(padding_type, (pad_beg_h, pad_end_h, pad_beg_w, pad_end_w))`` for a matching scheme, or
            ``(None, None)`` when no scheme matches.
        """
        if succ.op == LayerType.base_deconv and (
            pad_layer.top == pad_layer.bottom == pad_layer.left == pad_layer.right == 1
        ):
            # returned as a (type, amounts) pair like every other result, so callers can always unpack two values
            return PaddingType.deconv, (pad_layer.top, pad_layer.bottom, pad_layer.left, pad_layer.right)

        if succ.op in [LayerType.maxpool, LayerType.avgpool]:
            # pooling windows are used as-is, dilations are not taken into account
            dil_kernel_h = succ.kernel_height
            dil_kernel_w = succ.kernel_width
        else:
            dil_kernel_h = succ.kernel_height + ((succ.kernel_height - 1) * (succ.dilations[1] - 1))
            dil_kernel_w = succ.kernel_width + ((succ.kernel_width - 1) * (succ.dilations[2] - 1))

        strides_h = succ.strides[1]
        strides_w = succ.strides[2]
        input_h = pad_layer.input_height
        input_w = pad_layer.input_width

        pad_total_h = dil_kernel_h - strides_h if input_h % strides_h == 0 else dil_kernel_h - input_h % strides_h
        pad_total_h = max(pad_total_h, 0)

        pad_total_w = dil_kernel_w - strides_w if input_w % strides_w == 0 else dil_kernel_w - input_w % strides_w
        pad_total_w = max(pad_total_w, 0)

        pad_beg_h = int(pad_total_h / 2)
        pad_end_h = pad_total_h - pad_beg_h
        pad_beg_w = int(pad_total_w / 2)
        pad_end_w = pad_total_w - pad_beg_w
        amounts = (pad_beg_h, pad_end_h, pad_beg_w, pad_end_w)

        def _same_output_size():
            # True when the external padding yields the same number of window steps as the computed scheme
            same_h = int((pad_layer.top + input_h + pad_layer.bottom - dil_kernel_h) / strides_h) == int(
                (pad_beg_h + input_h + pad_end_h - dil_kernel_h) / strides_h
            )
            same_w = int((pad_layer.left + input_w + pad_layer.right - dil_kernel_w) / strides_w) == int(
                (pad_beg_w + input_w + pad_end_w - dil_kernel_w) / strides_w
            )
            return same_h and same_w

        # Check if padding is equivalent to same_tensorflow
        if pad_beg_h == pad_layer.top and pad_beg_w == pad_layer.left:
            if (pad_end_h == pad_layer.bottom and pad_end_w == pad_layer.right) or _same_output_size():
                return PaddingType.same_tensorflow, amounts

        # Check if padding is equivalent to same
        if pad_end_h == pad_layer.top and pad_end_w == pad_layer.left:
            if (pad_beg_h == pad_layer.bottom and pad_beg_w == pad_layer.right) or _same_output_size():
                return PaddingType.same, amounts

        return None, None

    def _fuse_padding_to_succ(self, pad_layer, padding, succ, layers_to_remove):
        self._fuse_and_validate_padding(succ, pad_layer, padding)
        self._fuser_helper.remove_layer(pad_layer, layers_to_remove, fuse_to_succ=True, fuse_to_pred=False)

    def _fuse_padding_after_feature_splitter(self, pad_layer, succ, split_succs, layers_to_remove):
        new_output_shapes = []
        for shape in succ.output_shapes:
            new_output_shapes.append([-1, pad_layer.input_height, pad_layer.input_width, shape[-1]])
        succ.output_shapes = new_output_shapes

        for split_succ in split_succs:
            split_succ.input_shape = [-1, pad_layer.input_height, pad_layer.input_width, split_succ.input_features]
            padding, _ = self._get_padding_scheme(pad_layer, split_succ)
            self._fuse_and_validate_padding(split_succ, pad_layer, padding)

        self._fuser_helper.remove_layer(pad_layer, layers_to_remove)

    def _fuse_and_validate_padding(self, dst_layer, pad_layer, padding):
        """
        Copy the padding type and padding constant value of an external padding layer onto ``dst_layer``.

        When ``dst_layer`` is an avgpool whose original names differ from the padding layer's, its
        ``count_include_pad`` flag is switched on as well.
        """
        dst_layer.padding = padding
        dst_layer.padding_const_value = pad_layer.padding_const_value

        if dst_layer.op != LayerType.avgpool:
            return

        # NOTE(review): differing original names presumably mean the padding was a separate op in the original
        # model, so the padded cells take part in the average - confirm
        if pad_layer.original_names != dst_layer.original_names:
            dst_layer.count_include_pad = True

    def _apply_padding_correction(self, avgpool_layer, padding_layout, has_external_pad=False):
        """
        Insert an element-wise multiplication after an avgpool layer that rescales its output per position.

        A matrix of ones with the spatial size of the un-padded input is extended by the requested padding rows
        and columns (filled with ``avgpool_layer.padding_const_value``), average pooled with the layer's kernel
        and strides, and turned into the per-position factor ``(h * w) / (window_average * h * w)``. The factor is
        stored in a new const input layer feeding a new ``EWMultLayer`` that is placed between the avgpool layer
        and its former successors.

        Args:
            avgpool_layer: the avgpool layer whose output is corrected.
            padding_layout: four padding amounts ordered ``(bottom, top, right, left)``.
            has_external_pad: when True, the un-padded size is read from ``input_shape`` of the avgpool's first
                graph predecessor (expected to be the external padding layer - confirm against the caller);
                otherwise it is the last entry of the avgpool's ``input_shapes``.
        """
        # entries 1 and 2 of the kernel shape are used as the window height and width
        h, w = avgpool_layer.kernel_shape[1:3]
        bottom, top, right, left = padding_layout
        # extracts the shape of the avgpool's input layer (without the padding)
        # in case of external padding it's the shape of the padding's input otherwise it's the avgpool's input shape
        original_input_size_without_padding = (
            next(iter(self._graph.predecessors(avgpool_layer))).input_shape
            if has_external_pad
            else avgpool_layer.input_shapes[-1]
        )
        # real input positions are marked with 1; padded positions get padding_const_value below
        padded_matrix = np.ones(original_input_size_without_padding[1:3])

        # rows are added first, so the columns added afterwards span the already padded height
        if bottom > 0:
            bottom_pad = np.ones((bottom, padded_matrix.shape[1])) * avgpool_layer.padding_const_value
            padded_matrix = np.concatenate([padded_matrix, bottom_pad], axis=0)

        if top > 0:
            top_pad = np.ones((top, padded_matrix.shape[1])) * avgpool_layer.padding_const_value
            padded_matrix = np.concatenate([top_pad, padded_matrix], axis=0)

        if right > 0:
            right_pad = np.ones((padded_matrix.shape[0], right)) * avgpool_layer.padding_const_value
            padded_matrix = np.concatenate([padded_matrix, right_pad], axis=1)

        if left > 0:
            left_pad = np.ones((padded_matrix.shape[0], left)) * avgpool_layer.padding_const_value
            padded_matrix = np.concatenate([left_pad, padded_matrix], axis=1)

        # add leading and trailing singleton axes so the 2D matrix becomes a 4D tensor for the pooling layer
        padded_matrix = np.expand_dims(padded_matrix, axis=0)
        padded_matrix = np.expand_dims(padded_matrix, axis=-1)

        # the correction matrix counts the number of elements that are supposed to be counted in each pooling
        # window, then takes the inverse of the count value (it's the dividing value in avg pool) and finally
        # multiplies the matrix by h*w for canceling the wrong denominator done in our wrong avg pool

        # takes the average of each window and multiplies by number of elements -> number of elements in each window
        correction_matrix = (
            tf.keras.layers.AveragePooling2D(
                pool_size=avgpool_layer.kernel_shape[1:3],
                strides=avgpool_layer.strides[1:3],
                padding="valid",
            )(padded_matrix)
            * h
            * w
        )
        # rescale correction matrix
        correction_matrix = (h * w) / correction_matrix

        successors_meta_data = {}
        ew_mult = EWMultLayer()
        ew_mult.name = f"{avgpool_layer.name}_ew_mult"
        ew_mult.original_names = avgpool_layer.original_names
        ew_mult.index = self._graph.get_next_index()

        # the const layer holds the correction factors (batch axis dropped) and takes the index right after ew_mult
        const_layer = ConstInputLayer.create(None, [avgpool_layer.output_shapes[0]], correction_matrix[0])
        const_layer.name = f"{avgpool_layer.name}_padding_correction_matrix"
        const_layer.index = ew_mult.index + 1
        const_layer.outputs = [ew_mult.name]
        const_layer.output_indices = [ew_mult.index]

        ew_mult.inputs = [const_layer.name, avgpool_layer.name]
        ew_mult.input_shapes = [const_layer.output_shape, avgpool_layer.output_shape]
        ew_mult.input_list = [const_layer, avgpool_layer]
        ew_mult.input_indices = [const_layer.index, avgpool_layer.index]

        # ew_mult takes over all the outputs of the avgpool layer
        ew_mult.outputs = avgpool_layer.outputs.copy()
        ew_mult.output_indices = avgpool_layer.output_indices.copy()
        ew_mult.output_shapes = avgpool_layer.output_shapes.copy()

        avgpool_layer.outputs = [ew_mult.name]
        avgpool_layer.output_indices = [ew_mult.index]
        # output shape stays the same

        self._graph.add_node(const_layer)
        self._graph.add_node(ew_mult)

        # re-route every former successor of the avgpool layer to read from ew_mult instead
        for succ in list(self._graph.successors(avgpool_layer)):
            succ.inputs[succ.inputs.index(avgpool_layer.name)] = ew_mult.name
            succ.input_indices[succ.input_indices.index(avgpool_layer.index)] = ew_mult.index
            # input shape stays the same
            self._graph.remove_edge(avgpool_layer, succ)
            self._graph.add_edge(ew_mult, succ)

        self._graph.add_edge(const_layer, ew_mult)
        self._graph.add_edge(avgpool_layer, ew_mult)
        HailoNN.update_successors_meta_data(ew_mult, successors_meta_data)
        HailoNN.update_successors_meta_data(const_layer, successors_meta_data)

        for layer in [const_layer, ew_mult]:
            self._graph.relax_new_layer_into_graph(layer, successors_meta_data)

    def _handle_group_convolutions(self):
        """
        Detect group convolutions expressed as splitter -> parallel convs -> concat and fuse each into one layer.

        Every concat layer with at least ``GROUP_CONV_MIN_NUM_OF_GROUPS`` predecessors is a candidate. Its first
        predecessor (or the producer of that predecessor, when it is a bias add) is the representative conv, and
        the producer of the representative conv is the candidate feature splitter. Candidates that pass
        ``_is_group_conv`` are replaced by a single conv/deconv layer carrying a ``groups`` parameter.
        """
        replaced_layers = []
        fused_layers = []
        successors_meta_data = {}

        for concat in list(self._graph):
            if concat.op != LayerType.concat:
                continue

            branch_tails = list(self._graph.predecessors(concat))
            groups = len(branch_tails)
            if groups < self.GROUP_CONV_MIN_NUM_OF_GROUPS:
                continue

            head = branch_tails[0]
            if head.op == LayerType.bias_add:
                # the conv is hidden behind its bias add
                first_conv = next(iter(self._graph.predecessors(head)))
            elif head.op in (LayerType.base_conv, LayerType.base_deconv):
                first_conv = head
            else:
                # neither conv/deconv nor bias_add, no need to look at its predecessor
                continue

            feature_splitter = next(iter(self._graph.predecessors(first_conv)))

            # all the concat inputs must be convs with the same input layer and the same parameters
            if self._is_group_conv(feature_splitter, first_conv, branch_tails, concat):
                # group conv found - fuse concat and all group conv layers to one conv/deconv layer.
                self._create_group_conv(
                    feature_splitter,
                    first_conv,
                    branch_tails,
                    concat,
                    groups,
                    replaced_layers,
                    fused_layers,
                    successors_meta_data,
                )

        # finalize graph manipulation outside graph iteration
        for replaced in replaced_layers:
            self._graph.remove_layer(replaced)

        for fused in fused_layers:
            self._graph.relax_new_layer_into_graph(fused, successors_meta_data)

    def _create_group_conv(
        self,
        feature_splitter,
        first_conv,
        concat_preds,
        concat,
        groups,
        layers_to_remove,
        new_layers,
        successors_meta_data,
    ):
        """
        Build one grouped conv/deconv layer that replaces a feature splitter, its parallel convs and their concat.

        The new layer copies op, dilations, strides and padding from ``first_conv``, multiplies the kernel shape
        entries from index 2 onwards by ``groups``, takes its inputs from the feature splitter and its outputs
        from the concat, and collects the parameters of the replaced layers through ``move_group_conv_params``.
        The new layer is appended to ``new_layers`` and the replaced layers to ``layers_to_remove``; the caller
        finalizes both lists.
        """
        # replaced layers, in order: concat, splitter, then per branch the optional bias add followed by its conv
        replaced = [concat, feature_splitter]
        for branch_tail in concat_preds:
            if branch_tail.op == LayerType.bias_add:
                replaced.extend([branch_tail, next(iter(self._graph.predecessors(branch_tail)))])
            else:
                replaced.append(branch_tail)

        spatial_dims = first_conv.kernel_shape[:2]
        grouped_feature_dims = [dim * groups for dim in first_conv.kernel_shape[2:]]

        group_conv = Conv2DLayer()
        group_conv.name = f"group_conv_{concat.name}"
        group_conv.index = self._graph.get_next_index()
        # conv or deconv, according to the representative conv layer
        group_conv.op = first_conv.op
        group_conv.kernel_shape = [*spatial_dims, *grouped_feature_dims]
        group_conv.dilations = first_conv.dilations
        group_conv.strides = first_conv.strides
        group_conv.padding = first_conv.padding
        group_conv.groups = groups
        # the fused layer sits between the splitter's producer(s) and the concat's consumer(s)
        group_conv.inputs = feature_splitter.inputs
        group_conv.input_indices = feature_splitter.input_indices
        group_conv.input_shapes = feature_splitter.input_shapes
        group_conv.outputs = concat.outputs
        group_conv.output_indices = concat.output_indices
        group_conv.output_shapes = concat.output_shapes
        group_conv.move_group_conv_params(replaced)

        self._graph.add_node(group_conv)

        for producer in list(self._graph.predecessors(feature_splitter)):
            self._graph.remove_edge(producer, feature_splitter)
            self._graph.add_edge(producer, group_conv)
            producer.replace_output_layer(feature_splitter.name, group_conv.name)
            producer.replace_output_index(feature_splitter.index, group_conv.index)

        for consumer in list(self._graph.successors(concat)):
            self._graph.remove_edge(concat, consumer)
            self._graph.add_edge(group_conv, consumer)
            consumer.replace_input_layer(concat.name, group_conv.name)
            consumer.replace_input_index(concat.index, group_conv.index)
            HailoNN.update_successors_meta_data(consumer, successors_meta_data)

        new_layers.append(group_conv)
        layers_to_remove.extend(replaced)

        self._logger.debug(f"Fused group-conv layer at {group_conv.name}.")

    def _is_group_conv(self, feature_splitter, first_conv, concat_preds, concat):
        """
        Check that the predecessors of ``concat`` form a group convolution rooted at ``feature_splitter``.

        Every concat predecessor (or its producer, when the predecessor is a bias add) must be a conv/deconv with
        the same op, input shapes, kernel shape, dilations, strides and padding as ``first_conv``, must not be
        grouped already, and must have ``feature_splitter`` - a feature splitter layer - as its only predecessor.

        Returns:
            bool: False (after a debug log) on the first predecessor that breaks the pattern, True otherwise.
        """
        for branch_tail in concat_preds:
            if branch_tail.op == LayerType.bias_add:
                candidate = next(iter(self._graph.predecessors(branch_tail)))
            else:
                candidate = branch_tail
            producers = list(self._graph.predecessors(candidate))

            breaks_pattern = (
                candidate.op not in [LayerType.base_conv, LayerType.base_deconv]
                or candidate.op != first_conv.op
                or candidate.input_shapes != first_conv.input_shapes
                or candidate.kernel_shape != first_conv.kernel_shape
                or candidate.dilations != first_conv.dilations
                or candidate.strides != first_conv.strides
                or candidate.padding != first_conv.padding
                or candidate.groups > 1
                or len(producers) != 1
                or feature_splitter != producers[0]
                or feature_splitter.op != LayerType.feature_splitter
            )
            if not breaks_pattern:
                continue

            self._logger.debug(
                f"Failed to fuse conv-group layer. The condition is that concat layer {concat.name} "
                f"should have {self.GROUP_CONV_MIN_NUM_OF_GROUPS} or more predecessors, all of which"
                f" are conv/deconv ops, with identical parameters (input shape, kernel shape, "
                f"strides, dilations). In addition, all conv/deconv layers must be successors of a "
                f"single feature splitter layer.",
            )
            return False

        return True

    def _fuse_output_format_conversion(self):
        """
        Fuse output format conversion layers into the layers that can produce the converted format themselves.

        Two patterns are handled, both requiring the conversion layer to have a single predecessor which in turn
        has a single successor and at least one predecessor:

        * a width<->features transpose (with ``groups <= 1``) after a normalization or a conv/dw 1x1 stride 1
          layer toggles ``transpose_output_width_features`` on that layer;
        * a flatten reshape whose predecessor is fed by a serial space to depth layer toggles
          ``spatial_flatten_output`` on the space to depth layer.

        The fused conversion layer is then removed from the output graph.
        """
        layers_to_remove = []
        for layer in list(self._output_graph):
            if layer.op != LayerType.format_conversion:
                continue

            conversion_type = layer.conversion_type
            is_width_features_transpose = conversion_type == FormatConversionType.transpose_width_features
            if not ((is_width_features_transpose and layer.groups <= 1) or layer.is_flatten_reshape):
                continue

            preds = list(self._output_graph.predecessors(layer))
            pred = preds[0]
            grandparents = list(self._output_graph.predecessors(pred))
            if not grandparents:
                continue
            grandparent = grandparents[0]

            is_pred_conv1x1s1 = (
                pred.op in [LayerType.conv, LayerType.dw]
                and pred.kernel_height == 1
                and pred.kernel_width == 1
                and pred.stride_height == 1
                and pred.stride_width == 1
            )

            # the conversion must be the only consumer of its only producer
            if len(preds) != 1:
                continue
            if len(list(self._output_graph.successors(pred))) != 1:
                continue

            fusable = (
                (pred.op == LayerType.normalization or is_pred_conv1x1s1) and is_width_features_transpose
            ) or (
                grandparent.op == LayerType.space_to_depth
                and layer.is_flatten_reshape
                and grandparent.space_to_depth_type == SpaceToDepthType.serial
            )
            if not fusable:
                continue

            # set flag according to output conversion type
            if is_width_features_transpose:
                pred.transpose_output_width_features = not pred.transpose_output_width_features
            elif layer.is_flatten_reshape:
                # HACK: space to depth is the only layer support spatial flatten output.
                # ResMLP use spatial flatten output after conv16x16 -> space_to_depth + conv1x1 (SDK-24136)
                grandparent.spatial_flatten_output = not grandparent.spatial_flatten_output

            self._fuser_helper.remove_layer(layer, layers_to_remove)

        # finalize graph manipulation outside graph iteration
        for fused_conversion in layers_to_remove:
            self._output_graph.remove_layer(fused_conversion)

    def _handle_depth_to_space_activations(self):
        """
        Move the activation that follows a depth to space layer into the conv layer that precedes it.

        Applies when the depth to space layer has exactly one predecessor, a conv with linear activation, and
        exactly one successor, an activation layer. The conv takes the activation type and parameters, the
        activation layer is removed, and its successors are re-connected directly to the depth to space layer.
        """
        absorbed_activations = []
        for layer in list(self._output_graph):
            if layer.op != LayerType.depth_to_space:
                continue

            consumers = list(self._output_graph.successors(layer))
            producers = list(self._output_graph.predecessors(layer))
            matches_pattern = (
                len(producers) == 1
                and producers[0].op == LayerType.conv
                and producers[0].activation == ActivationType.linear
                and len(consumers) == 1
                and consumers[0].op == LayerType.activation
            )
            if not matches_pattern:
                continue

            conv, activation = producers[0], consumers[0]

            conv.activation = activation.activation
            conv.move_params(activation)
            absorbed_activations.append(activation)
            self._output_graph.remove_edge(layer, activation)

            # bypass the activation: its consumers now read straight from the depth to space layer
            for downstream in list(self._output_graph.successors(activation)):
                layer.replace_output_layer(activation.name, downstream.name)
                layer.replace_output_index(activation.index, downstream.index)

                downstream.replace_input_index(activation.index, layer.index)
                downstream.replace_input_layer(activation.name, layer.name)
                self._output_graph.add_edge(layer, downstream)

        # finalize graph manipulation outside graph iteration
        for activation in absorbed_activations:
            self._output_graph.remove_layer(activation)

    def _handle_multiple_output_splits(self):
        """
        Add shortcut layers for feature splitter outputs that are consumed by more than one successor.

        For every split index that appears more than once in ``split_indices``, a shortcut layer is inserted
        after the feature splitter and all the successors reading that split are re-connected to the shortcut,
        so the splitter keeps a single output per split. The debug message is logged only for splitters that
        actually received a shortcut layer.
        """
        new_layers = []
        successors_meta_data = {}

        for layer in list(self._output_graph):
            if layer.op == LayerType.feature_splitter and len(layer.split_indices) > 0:
                # edge case: the used split was parsed as an activation that has an elementwise operation with itself,
                # for example GeLU which is x*f(x), both inputs to the mult are from our split.
                succs = list(self._output_graph.successors(layer))
                if (
                    len(succs) < len(layer.outputs)
                    and len(succs) == len(layer.output_indices)
                    and any(succ.op == LayerType.activation for succ in succs)
                ):
                    # drop the last duplicated entry of the last activation successor
                    act_name = [succ.name for succ in succs if succ.op == LayerType.activation][-1]
                    idx_to_remove = [i for i, x in enumerate(layer.outputs) if x == act_name][-1]
                    layer.outputs.pop(idx_to_remove)
                    layer.split_indices.pop(idx_to_remove)
                    layer.output_shapes.pop(idx_to_remove)

                shortcut_inserted = False
                # the counts are a snapshot; split_indices itself is edited while the shortcuts are wired
                split_counts = Counter(layer.split_indices)
                for split_index, count in split_counts.items():
                    if count > 1:
                        # the current split is used more than one time, creating new shortcut to aggregate the
                        # successors with the current split_index
                        shortcut_inserted = True
                        shortcut_layer = ShortcutLayer()
                        shortcut_layer.index = self._output_graph.get_next_index()
                        shortcut_layer.name = f"shortcut_layer_{layer.name}_{split_index}"
                        self._output_graph.add_node(shortcut_layer)
                        shortcut_layer.append_input_index(layer.index)
                        shortcut_layer.append_input_layer(layer.name)
                        shortcut_layer.move_params(layer)
                        new_layers.append(shortcut_layer)

                        # successors are listed before the shortcut edge is added, so the shortcut is not in the list
                        succs = list(self._output_graph.successors(layer))
                        self._output_graph.add_edge(layer, shortcut_layer)
                        for succ in succs:
                            succ_index_in_list = layer.output_indices.index(succ.index)
                            if layer.split_indices[succ_index_in_list] == split_index:
                                # the current layer belongs to the current split_index
                                # updating successor input details
                                succ.replace_input_index(layer.index, shortcut_layer.index)
                                succ.replace_input_layer(layer.name, shortcut_layer.name)
                                # input shape was not changed

                                self._output_graph.add_edge(shortcut_layer, succ)
                                self._output_graph.remove_edge(layer, succ)

                                if not shortcut_layer.input_shape:
                                    # this is the first layer in split_index cluster
                                    shortcut_layer.input_shape = layer.output_shapes[succ_index_in_list]
                                    layer.replace_output_index(succ.index, shortcut_layer.index)
                                    layer.replace_output_layer(succ.name, shortcut_layer.name)
                                else:
                                    # the details were updated in the phase before, no need to add or replace
                                    del layer.split_indices[succ_index_in_list]
                                    del layer.output_shapes[succ_index_in_list]
                                    layer.output_indices.remove(succ.index)
                                    layer.outputs.remove(succ.name)

                                # updating shortcut_layer->succ details
                                shortcut_layer.append_output_shape(shortcut_layer.input_shape)
                                shortcut_layer.append_output_index(succ.index)
                                shortcut_layer.append_output_layer(succ.name)
                                HailoNN.update_successors_meta_data(succ, successors_meta_data)

                if shortcut_inserted:
                    self._logger.debug(
                        f"Inserted a shortcut layer to enable multiple outputs for split in layer {layer.name}",
                    )

        # finalize graph manipulation outside graph iteration
        for layer in new_layers:
            self._output_graph.relax_new_layer_into_graph(layer, successors_meta_data)

    def _handle_avgpool1x1s1(self):
        """
        Remove average pooling layers with a 1x1 kernel and stride 1 from the graph.
        """

        def _is_avgpool1x1s1(candidate):
            return (
                candidate.op == LayerType.avgpool
                and candidate.kernel_width == 1
                and candidate.kernel_height == 1
                and candidate.stride_height == 1
                and candidate.stride_width == 1
            )

        redundant_pools = []
        for candidate in list(self._graph):
            if not _is_avgpool1x1s1(candidate):
                continue
            self._logger.debug(f"avgpool1x1s1 {candidate.name} was found. remove it from the graph")
            self._fuser_helper.remove_layer(candidate, redundant_pools)

        # finalize graph manipulation outside graph iteration
        for redundant_pool in redundant_pools:
            self._graph.remove_layer(redundant_pool)

    def _split_gcn_pooling_block(self):
        new_layers = []
        layers_to_remove = []

        for layer in list(self._graph):
            if (
                layer.op == LayerType.transpose
                and layer not in layers_to_remove
                and layer.is_first_in_gcn_block(self._graph)
            ):
                transpose = layer
                layers_to_remove.append(transpose)
                avgpool_of_transpose = next(iter(self._graph.predecessors(transpose)))
                concat = next(iter(self._graph.successors(transpose)))
                layers_to_remove.append(concat)
                avgpool2 = next(x for x in self._graph.predecessors(concat) if x.op == LayerType.avgpool)
                conv = next(iter(self._graph.successors(concat)))
                hardswish = next(iter(self._graph.successors(conv)))
                slices = list(self._graph.successors(hardswish))
                if next(iter(self._graph.successors(slices[0]))).op == LayerType.transpose:
                    slice_of_transpose = slices[0]
                    slice2 = slices[1]
                else:
                    slice_of_transpose = slices[1]
                    slice2 = slices[0]
                layers_to_remove.append(slice_of_transpose)
                layers_to_remove.append(slice2)
                last_transpose = next(iter(self._graph.successors(slice_of_transpose)))
                layers_to_remove.append(last_transpose)
                last_conv_of_transpose = next(iter(self._graph.successors(last_transpose)))
                last_conv2 = next(iter(self._graph.successors(slice2)))

                self._graph.remove_edge(avgpool_of_transpose, transpose)
                self._graph.remove_edge(avgpool2, concat)
                self._graph.remove_edge(concat, conv)
                self._graph.remove_edge(hardswish, slice_of_transpose)
                self._graph.remove_edge(hardswish, slice2)
                self._graph.remove_edge(last_transpose, last_conv_of_transpose)
                self._graph.remove_edge(slice2, last_conv2)

                avgpool_of_transpose.replace_output_layer(transpose.name, conv.name)
                avgpool_of_transpose.replace_output_index(transpose.index, conv.index)
                conv_output = avgpool_of_transpose.output_shape[:3] + [conv.output_shape[3]]
                conv.replace_input_index(concat.index, avgpool_of_transpose.index)
                conv.replace_input_shape(concat.name, avgpool_of_transpose.output_shape)
                conv.replace_input_layer(concat.name, avgpool_of_transpose.name)
                conv.output_shapes = [conv_output for _ in range(conv.output_copies)]
                hardswish.replace_input_shape(conv.name, conv_output)
                hardswish.output_indices = [last_conv_of_transpose.index]
                hardswish.outputs = [last_conv_of_transpose.name]
                hardswish.output_shapes = [conv_output]
                last_conv_of_transpose.replace_input_index(last_transpose.index, hardswish.index)
                last_conv_of_transpose.replace_input_shape(last_transpose.name, conv_output)
                last_conv_of_transpose.replace_input_layer(transpose.name, hardswish.name)

                self._graph.add_edge(avgpool_of_transpose, conv)
                self._graph.add_edge(hardswish, last_conv_of_transpose)

                conv_dup = Conv2DLayer.from_layer(conv)
                conv_dup.move_params(conv)
                conv_dup.index = self._graph.get_next_index()
                conv_dup.name = f"{conv_dup.name}_dup"
                new_layers.append(conv_dup)
                self._graph.add_node(conv_dup)
                hardswish_dup = ActivationLayer.from_layer(hardswish)
                hardswish_dup.move_params(hardswish)
                hardswish_dup.index = self._graph.get_next_index()
                hardswish_dup.name = f"{hardswish_dup.name}_dup"
                new_layers.append(hardswish_dup)
                self._graph.add_node(hardswish_dup)

                avgpool2.replace_output_layer(concat.name, conv_dup.name)
                avgpool2.replace_output_index(concat.index, conv_dup.index)
                conv_dup_output = avgpool2.output_shape[:3] + [conv_dup.output_shape[3]]
                conv_dup.replace_input_index(avgpool_of_transpose.index, avgpool2.index)
                conv_dup.replace_input_shape(avgpool_of_transpose.name, avgpool2.output_shape)
                conv_dup.replace_input_layer(avgpool_of_transpose.name, avgpool2.name)
                conv_dup.output_shapes = [conv_dup_output for _ in range(conv_dup.output_copies)]
                conv_dup.outputs = [hardswish_dup.name]
                conv_dup.replace_output_index(hardswish.index, hardswish_dup.index)
                hardswish_dup.replace_input_index(conv.index, conv_dup.index)
                hardswish_dup.replace_input_shape(conv.name, conv_dup_output)
                hardswish_dup.replace_input_layer(conv.name, conv_dup.name)
                hardswish_dup.output_indices = [last_conv2.index]
                hardswish_dup.output_shapes = [conv_dup_output]
                hardswish_dup.outputs = [last_conv2.name]
                last_conv2.replace_input_index(slice2.index, hardswish_dup.index)
                last_conv2.replace_input_shape(slice2.name, conv_dup_output)
                last_conv2.replace_input_layer(slice2.name, hardswish_dup.name)

                self._graph.add_edge(avgpool2, conv_dup)
                self._graph.add_edge(conv_dup, hardswish_dup)
                self._graph.add_edge(hardswish_dup, last_conv2)

        # finalize graph manipulation outside graph iteration
        for layer in layers_to_remove:
            self._graph.remove_layer(layer)

        for layer in new_layers:
            self._graph.relax_new_layer_into_graph(layer)

    def _handle_hc_transpose(self):
        """
        Split every transpose layer with perm ``[0, 3, 2, 1]`` into 3 supported transpose (format conversion)
        layers:
        transpose_width_features -> transpose_height_width -> transpose_width_features

        The three new layers are wired between the single predecessor of the original transpose and all of its
        successors. The original transpose is removed once the graph iteration is over.
        """
        new_layers = []
        layers_to_remove = []
        for layer in list(self._graph):
            if layer.op != LayerType.transpose or layer.perm != [0, 3, 2, 1]:
                continue

            # first transpose takes over the inputs of the original layer
            transpose1 = FormatConversionLayer()
            transpose1.conversion_type = FormatConversionType.transpose_width_features
            transpose1.index = self._graph.get_next_index()
            transpose1.name = f"transpose1_after_{layer.name}"
            transpose1.input_indices = layer.input_indices
            transpose1.inputs = layer.inputs
            new_layers.append(transpose1)

            # the new layers are added to the graph only further below, so the following indices are derived
            # manually from the first one
            transpose2 = FormatConversionLayer()
            transpose2.conversion_type = FormatConversionType.transpose_height_width
            transpose2.index = transpose1.index + 1
            transpose2.name = f"transpose2_after_{transpose1.name}"
            transpose2.input_indices = [transpose1.index]
            transpose2.inputs = [transpose1.name]
            new_layers.append(transpose2)

            # last transpose takes over the outputs of the original layer
            transpose3 = FormatConversionLayer()
            transpose3.conversion_type = FormatConversionType.transpose_width_features
            transpose3.index = transpose2.index + 1
            transpose3.name = f"transpose3_after_{transpose2.name}"
            transpose3.input_indices = [transpose2.index]
            transpose3.inputs = [transpose2.name]
            new_layers.append(transpose3)

            transpose1.outputs = [transpose2.name]
            transpose1.output_indices = [transpose2.index]

            transpose2.outputs = [transpose3.name]
            transpose2.output_indices = [transpose3.index]

            transpose3.outputs = layer.outputs
            transpose3.output_indices = layer.output_indices

            # only the first predecessor is rewired
            pred = next(iter(self._graph.predecessors(layer)))
            pred.replace_output_layer(layer.name, transpose1.name)
            pred.replace_output_index(layer.index, transpose1.index)

            succs = list(self._graph.successors(layer))
            for succ in succs:
                succ.replace_input_layer(layer.name, transpose3.name)
                succ.replace_input_index(layer.index, transpose3.index)
                self._graph.remove_edge(layer, succ)
                self._graph.add_edge(transpose3, succ)

            # transpose1 swaps axes 2 and 3 of the original input shape
            transpose1.input_shapes = layer.input_shapes
            transpose1.output_shapes = [
                transpose1.input_shape[0],
                transpose1.input_shape[1],
                transpose1.input_shape[3],
                transpose1.input_shape[2],
            ]

            # transpose2 swaps axes 1 and 2 of transpose1's output shape
            transpose2.input_shapes = transpose1.output_shapes
            transpose2.output_shapes = [
                transpose2.input_shape[0],
                transpose2.input_shape[2],
                transpose2.input_shape[1],
                transpose2.input_shape[3],
            ]

            # transpose3 consumes transpose2's output (previously it was wrongly fed with its own, not yet
            # assigned, output shapes) and produces the output shapes of the original layer
            transpose3.input_shapes = transpose2.output_shapes
            transpose3.output_shapes = layer.output_shapes

            self._graph.add_node(transpose1)
            self._graph.add_node(transpose2)
            self._graph.add_node(transpose3)

            self._graph.remove_edge(pred, layer)
            self._graph.add_edge(pred, transpose1)

            self._graph.add_edge(transpose1, transpose2)
            self._graph.add_edge(transpose2, transpose3)
            layers_to_remove.append(layer)

        # finalize graph manipulation outside graph iteration
        for layer in layers_to_remove:
            self._graph.remove_layer(layer)

        for layer in new_layers:
            self._graph.relax_new_layer_into_graph(layer)

    def _handle_deconv1x1s1(self):
        """
        Finds deconv layers whose kernel height/width and stride height/width are all 1 (deconv1x1s1) and
        replaces each of them with a conv1x1s1 layer.

        The replacing conv copies the connectivity, shapes, weights and activation of the deconv and takes over
        its edges in the output graph. The deconv layers are removed after the graph iteration.
        """
        layers_to_remove = []
        new_layers = []
        successors_meta_data = {}
        for layer in list(self._output_graph):
            if layer.op == LayerType.deconv and (
                layer.kernel_height == layer.stride_height == layer.kernel_width == layer.stride_width == 1
            ):
                self._logger.debug(f"deconv1x1s1 {layer.name} was found. replace it with conv1x1s1")
                conv = FusedConv2DLayer()
                new_layers.append(conv)
                # the node is added to the graph before its name and index are assigned
                self._output_graph.add_node(conv)
                conv.name = f"conv_{layer.name}"
                conv.original_names = layer.original_names
                conv.op = LayerType.conv
                conv.index = self._output_graph.get_next_index()
                conv.kernel_shape = layer.kernel_shape
                conv.dilations = layer.dilations
                conv.strides = layer.strides
                conv.padding = PaddingType.valid

                # the conv takes over the inputs of the deconv as they are
                conv.inputs = layer.inputs
                conv.input_indices = layer.input_indices
                conv.input_shapes = layer.input_shapes
                conv.input_vertex_order = layer.input_vertex_order

                # the conv takes over the outputs of the deconv as they are
                conv.outputs = layer.outputs
                conv.output_indices = layer.output_indices
                conv.output_shapes = layer.output_shapes
                conv.move_params(layer)

                # weights and layer attributes are copied one by one from the deconv
                conv.bias = layer.bias
                conv.kernel = layer.kernel
                conv.pre_layer_bias = layer.pre_layer_bias
                conv.dynamic_weights = layer.dynamic_weights
                conv.spatial_flatten_output = layer.spatial_flatten_output
                conv.is_dilated_s2b = layer.is_dilated_s2b
                conv.groups = layer.groups
                conv.activation = layer.activation

                preds = list(self._output_graph.predecessors(layer))
                succs = list(self._output_graph.successors(layer))
                for succ in succs:
                    succ.replace_input_index(layer.index, conv.index)
                    succ.replace_input_layer(layer.name, conv.name)
                    # NOTE(review): the successor is given conv.input_shape as its new input shape;
                    # conv.output_shape may have been intended here - confirm
                    succ.replace_input_shape(layer.name, conv.input_shape)
                    self._output_graph.remove_edge(layer, succ)
                    self._output_graph.add_edge(conv, succ)
                    HailoNN.update_successors_meta_data(succ, successors_meta_data)

                for pred in preds:
                    pred.replace_output_index(layer.index, conv.index)
                    pred.replace_output_layer(layer.name, conv.name)
                    self._output_graph.remove_edge(pred, layer)
                    self._output_graph.add_edge(pred, conv)

                layers_to_remove.append(layer)

        # finalize graph manipulation outside graph iteration
        for layer_to_remove in layers_to_remove:
            self._output_graph.remove_layer(layer_to_remove)

        for layer in new_layers:
            self._output_graph.relax_new_layer_into_graph(layer, successors_meta_data)

    def _handle_null_concat(self):
        """
        Bypasses and removes every concat layer that has exactly one predecessor, connecting that predecessor
        directly to all of the concat's successors.
        """
        bypassed_concats = []
        for concat in list(self._graph):
            if concat.op != LayerType.concat:
                continue

            consumers = list(self._graph.successors(concat))
            producers = list(self._graph.predecessors(concat))
            if len(producers) != 1:
                continue

            producer = producers[0]
            for position, consumer in enumerate(consumers):
                if position != 0:
                    # every further successor is appended as an additional output of the predecessor
                    producer.append_output_index(consumer.index)
                    producer.append_output_layer(consumer.name)
                    producer.append_output_shape(consumer.input_shape)
                else:
                    # the first successor takes the place of the concat in the predecessor's outputs
                    producer.replace_output_index(concat.index, consumer.index)
                    producer.replace_output_layer(concat.name, consumer.name)

                consumer.replace_input_index(concat.index, producer.index)
                consumer.replace_input_layer(concat.name, producer.name)
                # NOTE(review): the predecessor's input shape (not its output shape) is passed here - confirm
                consumer.replace_input_shape(concat.name, producer.input_shape)

                self._graph.remove_edge(concat, consumer)
                self._graph.add_edge(producer, consumer)

            producer.move_params(concat)
            self._graph.remove_edge(producer, concat)
            self._logger.debug(f"Removed 1 input concat layer {concat.name}")
            bypassed_concats.append(concat)

        # finalize graph manipulation outside graph iteration
        for concat in bypassed_concats:
            self._graph.remove_layer(concat)

    def _handle_null_slice(self):
        """
        Removes redundant slice layers: slice / base_slice layers whose last slice element is 1 on the height,
        width and features axes and whose output shape equals their input shape.
        """
        redundant_slices = []
        slice_ops = [LayerType.slice, LayerType.base_slice]
        for candidate in list(self._graph):
            if candidate.op not in slice_ops:
                continue

            axis_slices = [candidate.height_slice, candidate.width_slice, candidate.features_slice]
            if any(axis_slice[-1] != 1 for axis_slice in axis_slices):
                continue

            if not candidate.input_shape == candidate.output_shape:
                continue

            self._fuser_helper.remove_layer(candidate, redundant_slices)
            self._logger.debug(f"Removed redundant slice layer {candidate.name}")

        # finalize graph manipulation outside graph iteration
        for redundant_slice in redundant_slices:
            self._graph.remove_layer(redundant_slice)

    def _ends_with_flat_layers(self, layer):
        """
        Checks whether the path that starts at ``layer`` ends with flat layers.

        Returns True for dense, softmax and output layers, for a nearest neighbor resize with a 1x1 input and for
        a slice with a 1x1 spatial input whose height and width slices are [0, 1, 1]. A flat_to_frames format
        conversion and a conv1x1 are checked recursively through all of their successors (a batch norm with a
        single successor that follows a conv1x1 is skipped). Any other layer returns False.
        """
        if layer.op in [LayerType.dense, LayerType.softmax, LayerType.output_layer]:
            return True

        if (
            layer.op == LayerType.resize
            and layer.resize_method == ResizeMethod.nearest_neighbor
            and layer.input_height == layer.input_width == 1
        ):
            return True

        if layer.op == LayerType.format_conversion and layer.conversion_type == FormatConversionType.flat_to_frames:
            for follower in list(self._output_graph.successors(layer)):
                if not self._ends_with_flat_layers(follower):
                    return False
            return True

        if layer.op == LayerType.slice:
            spatial_is_1x1 = layer.input_shape[1:3] == [1, 1]
            return spatial_is_1x1 and layer.height_slice == layer.width_slice == [0, 1, 1]

        if layer.op == LayerType.conv and layer.kernel_width == 1 and layer.kernel_height == 1:
            for follower in list(self._output_graph.successors(layer)):
                endpoint = follower
                if follower.op == LayerType.batch_norm and len(list(self._output_graph.successors(follower))) == 1:
                    # skipping batch norm after conv1x1 because it will be folded in post fuser
                    endpoint = next(iter(self._output_graph.successors(follower)))
                if not self._ends_with_flat_layers(endpoint):
                    return False
            return True

        return False

    def _remove_format_conversion_after_flat_layer(self, flat_layer, layers_to_remove):
        """
        Bypasses every flat_to_frames format conversion that directly follows ``flat_layer`` and whose successors
        are all conv layers with a 1x1 kernel.

        The successors of the bypassed format conversion become direct successors of ``flat_layer``, and the format
        conversion is appended to ``layers_to_remove`` (the actual removal is left to the caller).
        """
        for reshape in list(self._output_graph.successors(flat_layer)):
            if reshape.op != LayerType.format_conversion:
                continue
            if reshape.conversion_type != FormatConversionType.flat_to_frames:
                continue

            consumers = list(self._output_graph.successors(reshape))
            only_conv1x1_consumers = all(
                consumer.op == LayerType.conv and consumer.kernel_height == 1 and consumer.kernel_width == 1
                for consumer in consumers
            )
            if not only_conv1x1_consumers:
                continue

            flat_layer.move_params(reshape)
            self._output_graph.remove_edge(flat_layer, reshape)

            for position, consumer in enumerate(consumers):
                if position != 0:
                    # the rest of the reshape successors are appended as additional outputs of the flat layer
                    flat_layer.append_output_index(consumer.index)
                    flat_layer.append_output_layer(consumer.name)
                    flat_layer.append_output_shape(consumer.input_shape)
                else:
                    # the first reshape successor takes the place of the reshape in the flat layer's outputs
                    flat_layer.replace_output_index(reshape.index, consumer.index)
                    flat_layer.replace_output_layer(reshape.name, consumer.name)

                consumer.replace_input_index(reshape.index, flat_layer.index)
                consumer.replace_input_layer(reshape.name, flat_layer.name)
                consumer.replace_input_shape(reshape.name, flat_layer.output_shape)
                self._output_graph.add_edge(flat_layer, consumer)

            layers_to_remove.append(reshape)
            self._logger.debug(
                f"Removed flat to frames reshape layer {reshape.name} that is followed by 1x1 "
                f"conv, after global average pool layer {flat_layer.name}",
            )

    def _switch_conv1x1_to_dense(self, conv, successors_meta_data, new_layers, layers_to_remove):
        """
        Creates a dense layer that replaces ``conv`` in the output graph and rewires the predecessors and
        successors of ``conv`` to it.

        Args:
            conv: the conv layer to replace. It is appended to ``layers_to_remove`` and detached from its
                neighbors, but it is not removed from the graph here.
            successors_meta_data: dict that is updated through ``HailoNN.update_successors_meta_data`` for every
                successor of ``conv``.
            new_layers: list to which the new dense layer is appended.
            layers_to_remove: list to which ``conv`` is appended.

        Returns:
            The new dense layer.

        """
        layers_to_remove.append(conv)
        dense = FusedDenseLayer()
        # the node is added to the graph before its name and index are assigned
        self._output_graph.add_node(dense)
        dense.name = f"dense_{conv.name}"
        dense.original_names = conv.original_names
        dense.op = LayerType.dense
        dense.index = self._output_graph.get_next_index()

        dense.inputs = conv.inputs
        dense.input_indices = conv.input_indices
        dense.input_shapes = conv.input_shapes

        dense.outputs = conv.outputs
        dense.output_indices = conv.output_indices
        # every output of the dense layer is rank 2: [-1, output features of the conv]
        dense.output_shapes = [[-1, conv.output_features] for _ in conv.output_shapes]
        dense.activation = conv.activation
        dense.bn_enabled = conv.bn_enabled
        dense.pre_layer_bn = conv.pre_layer_bn
        dense.kernel_shape = [conv.input_features, conv.output_features]
        dense.should_squeeze_kernel = True
        dense.move_params(conv)

        preds = list(self._output_graph.predecessors(conv))
        succs = list(self._output_graph.successors(conv))
        for succ in succs:
            succ.replace_input_index(conv.index, dense.index)
            succ.replace_input_shape(conv.name, dense.output_shape)
            succ.replace_input_layer(conv.name, dense.name)
            self._output_graph.add_edge(dense, succ)
            self._output_graph.remove_edge(conv, succ)
            HailoNN.update_successors_meta_data(succ, successors_meta_data)
            if succ.op == LayerType.batch_norm:
                # a batch norm successor gets output shapes equal to its (just updated) input shapes
                succ.output_shapes = succ.input_shapes

        for pred in preds:
            pred.replace_output_layer(conv.name, dense.name)
            pred.replace_output_index(conv.index, dense.index)
            # the predecessor's output shape towards the dense layer is set to the dense input shape
            pred.output_shapes[pred.outputs.index(dense.name)] = dense.input_shape
            self._output_graph.add_edge(pred, dense)
            self._output_graph.remove_edge(pred, conv)

        new_layers.append(dense)

        return dense

    def _handle_conv1x1_chains_after_dense_switch(self, layer, layers_to_remove, successors_meta_data, new_layers):
        """
        Continues the conv1x1 to dense switch along the chains that follow ``layer``.

        Every successor of ``layer`` (or the single successor of a batch norm that follows ``layer``) that is a
        conv with a 1x1 kernel is switched to a dense layer, and the search continues recursively from the new
        dense layer. Any other successor only gets its input shapes updated from the output shapes of ``layer``.
        """
        graph = self._output_graph
        for follower in list(graph.successors(layer)):
            source_outputs = layer.outputs
            target = follower
            if follower.op == LayerType.batch_norm and len(list(graph.successors(follower))) == 1:
                # skipping batch norm after conv1x1 because it will be folded in post fuser
                source_outputs = follower.outputs
                target = next(iter(graph.successors(follower)))

            # NOTE(review): unlike the caller's first switch, groups == 1 is not required here - confirm
            is_conv1x1 = target.op == LayerType.conv and target.kernel_height == target.kernel_width == 1
            if not is_conv1x1:
                target.input_shapes = [layer.output_shapes[source_outputs.index(target.name)]]
                continue

            dense = self._switch_conv1x1_to_dense(target, successors_meta_data, new_layers, layers_to_remove)
            self._handle_conv1x1_chains_after_dense_switch(
                dense,
                layers_to_remove,
                successors_meta_data,
                new_layers,
            )

    def _handle_conv1x1_after_global_avgpool(self):
        """
        Switches conv1x1 layers (with a single group) that follow a global avgpool layer or a fused dense layer to
        dense layers, as long as all the paths that leave that layer end with flat layers.

        flat_to_frames format conversions between the flat layer and the conv1x1 layers are removed first, and
        chains of conv1x1 layers that follow a switched conv are switched as well.
        """
        layers_to_remove = []
        new_layers = []
        successors_meta_data = {}

        for layer in list(self._output_graph):
            is_global_avgpool = isinstance(layer, PoolingLayer) and layer.is_global_avg_pool()
            if not (is_global_avgpool or isinstance(layer, FusedDenseLayer)):
                continue

            # checked if the current avgpool is in fact a global avgpool layer
            self._logger.debug(
                f"Found global avgpool layer, input shape = {layer.input_shape} "
                f"kernel shape = {layer.kernel_shape}",
            )

            followers = list(self._output_graph.successors(layer))
            if not all(self._ends_with_flat_layers(follower) for follower in followers):
                continue

            self._remove_format_conversion_after_flat_layer(layer, layers_to_remove)
            # finds successor layer of type conv and 1x1 kernel shape - and switches op to dense (and reshape
            # kernel)
            for follower in list(self._output_graph.successors(layer)):
                if follower.op != LayerType.conv:
                    continue
                if not (follower.kernel_height == follower.kernel_width == follower.groups == 1):
                    continue

                dense = self._switch_conv1x1_to_dense(follower, successors_meta_data, new_layers, layers_to_remove)
                self._handle_conv1x1_chains_after_dense_switch(
                    dense,
                    layers_to_remove,
                    successors_meta_data,
                    new_layers,
                )

        # finalize graph manipulation outside graph iteration
        for stale_layer in layers_to_remove:
            self._output_graph.remove_layer(stale_layer)

        for fresh_layer in new_layers:
            self._output_graph.relax_new_layer_into_graph(fresh_layer, successors_meta_data)

    def _handle_spatial_flatten_before_dynamic_weights_layer(self):
        """
        Moves a flatten reshape that feeds the second input of a dynamic weights layer to run before the conv1x1
        that precedes it:
        conv_pred -> conv1x1 -> flatten reshape -> layer   ==>   conv_pred -> flatten reshape -> conv1x1 -> layer

        The swap is done only when the second input of the dynamic weights layer is a flatten reshape format
        conversion with a single predecessor and a single successor, and that predecessor is a base conv with a
        1x1 kernel that also has a single predecessor and a single successor.
        """
        for layer in list(self._graph):
            if not layer.dynamic_weights:
                continue

            pred = self._graph.get_layer_by_name(layer.inputs[1])
            if pred.op != LayerType.format_conversion:
                continue
            format_conversion = pred

            if not format_conversion.is_flatten_reshape:
                continue

            format_conversion_preds = list(self._graph.predecessors(format_conversion))
            if len(format_conversion_preds) != 1:
                continue

            format_conversion_succs = list(self._graph.successors(format_conversion))
            if len(format_conversion_succs) != 1:
                continue

            format_conversion_pred = format_conversion_preds[0]
            if format_conversion_pred.op != LayerType.base_conv:
                continue
            conv = format_conversion_pred

            if conv.kernel_height != 1 or conv.kernel_width != 1:
                continue

            conv_preds = list(self._graph.predecessors(conv))
            if len(conv_preds) != 1:
                continue
            conv_pred = conv_preds[0]

            conv_succs = list(self._graph.successors(conv))
            if len(conv_succs) != 1:
                continue

            # replace format conversion with conv
            conv_pred.replace_output_layer(conv.name, format_conversion.name)
            conv_pred.replace_output_index(conv.index, format_conversion.index)

            # the format conversion is now fed by conv_pred and feeds the conv
            format_conversion.replace_input_layer(conv.name, conv_pred.name)
            format_conversion.replace_input_index(conv.index, conv_pred.index)
            format_conversion.replace_input_shape(conv_pred.name, conv_pred.output_shape)
            format_conversion.replace_output_layer(layer.name, conv.name)
            # the output index has to follow the output name (it was wrongly replaced in the input indices)
            format_conversion.replace_output_index(layer.index, conv.index)
            input_shape = format_conversion.input_shape
            # axes 1 and 2 of the new input are flattened into axis 2
            format_conversion.output_shapes = [-1, 1, input_shape[1] * input_shape[2], input_shape[3]]

            # the conv is now fed by the format conversion and feeds the dynamic weights layer
            conv.replace_input_layer(conv_pred.name, format_conversion.name)
            conv.replace_input_index(conv_pred.index, format_conversion.index)
            conv.replace_input_shape(format_conversion.name, format_conversion.output_shape)
            conv.replace_output_layer(format_conversion.name, layer.name)
            conv.replace_output_index(format_conversion.index, layer.index)
            # the conv keeps the flattened layout of its input and only changes the number of features
            conv.output_shapes = [[*conv.input_shape[:-1], conv.output_features]]

            layer.replace_input_layer(format_conversion.name, conv.name)
            layer.replace_input_index(format_conversion.index, conv.index)

            self._graph.remove_edge(conv_pred, conv)
            self._graph.remove_edge(conv, format_conversion)
            self._graph.remove_edge(format_conversion, layer)
            self._graph.add_edge(conv_pred, format_conversion)
            self._graph.add_edge(format_conversion, conv)
            self._graph.add_edge(conv, layer)

    def _set_avgpool_attrs(self, mean, kernel_h, kernel_w):
        """
        Sets ``mean`` to be an avgpool with valid padding whose kernel shape and strides are both
        [1, kernel_h, kernel_w, 1].
        """
        window = (1, kernel_h, kernel_w, 1)
        mean.op = LayerType.avgpool
        mean.padding = PaddingType.valid
        # two separate lists, so the kernel shape and the strides never share one object
        mean.kernel_shape = list(window)
        mean.strides = list(window)

    def _set_inst_norm_square_attrs(self, square):
        """
        Sets the feature multiplier type of ``square`` to square.
        """
        square_type = FeatureMultiplierType.square
        square.feature_multiplier_type = square_type

    def _set_inst_norm_conv_attrs(self, conv, groups, kernel):
        """
        Sets ``conv`` to be a group conv with the given ``kernel``: ``groups`` groups of size 1, unit strides and
        dilations and valid padding. Axis 2 of the kernel shape is multiplied by ``groups``.
        """
        conv.groups = groups
        conv.kernel = kernel
        dims = kernel.shape
        conv.kernel_shape = [dims[0], dims[1], dims[2] * groups, dims[3]]
        conv.strides = [1] * 4
        conv.dilations = [1] * 4
        conv.padding = PaddingType.valid
        conv.group_sizes = [1 for _ in range(groups)]

    def _split_log_softmax_layers(self):
        """
        Split log softmax layers according to the formula:
        log_softmax(x) = log_softmax(x - max(x)) = x - max(x) - log(sum(exp(x - max(x))))
        This decomposition is used to avoid exp of large values and log of values close to 0 for numerical stability.

        Each log softmax layer is replaced with the following layers (all created with a LOG_SOFTMAX block info):
        reduce_max1 = max(x), ew_sub1 = x - reduce_max1, activation2 = exp(ew_sub1), reduce_sum1 = sum(activation2),
        activation1 = log(reduce_sum1), ew_sub2 = ew_sub1 - activation1.
        When the input is of rank 2, a flat_to_frames format conversion is added before these layers.
        """
        new_layers = []
        layers_to_remove = []
        # stays empty in this function, it is only passed to relax_new_layer_into_graph
        successors_meta_data = {}

        for layer in list(self._graph):
            if layer.op != LayerType.log_softmax:
                continue

            input_shape = layer.input_shape.copy()
            # only the first predecessor is used as the input of the new block
            pred = next(iter(self._graph.predecessors(layer)))
            succs = list(self._graph.successors(layer))
            block_info = (BlockType.LOG_SOFTMAX, layer.name)
            flat_to_frames = None
            next_index = self._graph.get_next_index()

            if len(input_shape) == 2:
                # logsoftmax of rank2, adding flat_to_frames to work over rank4
                input_shape = [input_shape[0], 1, 1, input_shape[1]]
                flat_to_frames = self._fuser_helper.create_layer(
                    FormatConversionLayer,
                    next_index,
                    "flat_to_frames1",
                    layer,
                    new_layers,
                    [input_shape],
                    block_info,
                )
                flat_to_frames.conversion_type = FormatConversionType.flat_to_frames
                next_index = flat_to_frames.index + 1

            # same as the input shape, with 1 in the position of the reduced axis
            # NOTE(review): layer.axis is compared as is to the positions of the (possibly rank 4 expanded) shape;
            # presumably it is already a non-negative rank 4 axis - confirm
            reduced_shape = [x if i != layer.axis else 1 for i, x in enumerate(input_shape)]
            # max(x)
            reduce_max = self._fuser_helper.create_layer(
                ReduceMaxLayer,
                next_index,
                "reduce_max1",
                layer,
                new_layers,
                reduced_shape,
                block_info,
            )
            reduce_max.reduce_axes = [layer.axis]
            # x - max(x)
            ew_sub_of_max = self._fuser_helper.create_layer(
                EWSubLayer,
                reduce_max.index + 1,
                "ew_sub1",
                layer,
                new_layers,
                [input_shape],
                block_info,
            )
            # sum(exp(x - max(x)))
            reduce_sum = self._fuser_helper.create_layer(
                ReduceSumLayer,
                ew_sub_of_max.index + 1,
                "reduce_sum1",
                layer,
                new_layers,
                reduced_shape,
                block_info,
            )
            reduce_sum.reduce_axes = [layer.axis]
            # log(sum(exp(x - max(x))))
            log_activation = self._fuser_helper.create_layer(
                ActivationLayer,
                reduce_sum.index + 1,
                "activation1",
                layer,
                new_layers,
                reduced_shape,
                block_info,
            )
            log_activation.activation = ActivationType.log
            # exp(x - max(x)); created after the log activation, but wired before the reduce sum below
            exp_activation = self._fuser_helper.create_layer(
                ActivationLayer,
                log_activation.index + 1,
                "activation2",
                layer,
                new_layers,
                [input_shape],
                block_info,
            )
            exp_activation.activation = ActivationType.exp
            # x - max(x) - log(sum(exp(x - max(x)))), the output of the whole block
            ew_sub = self._fuser_helper.create_layer(
                EWSubLayer,
                exp_activation.index + 1,
                "ew_sub2",
                layer,
                new_layers,
                [input_shape],
                block_info,
            )

            # detach the original layer from its neighbors
            layers_to_remove.append(layer)
            self._graph.remove_edge(pred, layer)
            for succ in succs:
                self._graph.remove_edge(layer, succ)

            # the inputs of every layer in the new block, and of the original successors
            layer_to_inputs = {
                reduce_max: [pred],
                ew_sub_of_max: [pred, reduce_max],
                exp_activation: [ew_sub_of_max],
                reduce_sum: [exp_activation],
                log_activation: [reduce_sum],
                ew_sub: [ew_sub_of_max, log_activation],
            }
            layer_to_inputs.update(
                {succ: [ew_sub] for succ in succs},
            )

            # the outputs of every layer in the new block, and of the original predecessor
            layer_to_outputs = {
                pred: [ew_sub_of_max, reduce_max],
                reduce_max: [ew_sub_of_max],
                log_activation: [ew_sub],
                exp_activation: [reduce_sum],
                ew_sub_of_max: [exp_activation, ew_sub],
                reduce_sum: [log_activation],
                ew_sub: succs,
            }

            if flat_to_frames is not None:
                # the format conversion is placed between the predecessor and the rest of the block
                layer_to_inputs[flat_to_frames] = [pred]
                layer_to_inputs[reduce_max] = [flat_to_frames]
                layer_to_inputs[ew_sub_of_max] = [flat_to_frames, reduce_max]
                layer_to_outputs[pred] = [flat_to_frames]
                layer_to_outputs[flat_to_frames] = [ew_sub_of_max, reduce_max]

            self._fuser_helper.handle_new_preds_succs(layer_to_inputs, layer_to_outputs)
            self._logger.debug(
                f"Replaced {layer.full_name_msg} with layers that run on LCU",
            )

        # finalize graph manipulation outside graph iteration
        for layer in layers_to_remove:
            self._graph.remove_layer(layer)

        for layer in new_layers:
            self._graph.relax_new_layer_into_graph(layer, successors_meta_data)

    def _split_one_hot_layers(self):
        """
        Split one hot layers to conv 1x1 with delta activation

        The conv has a kernel of ones with a single input feature and num_classes output features, and a bias of
        0, -1, ..., -(num_classes - 1), so output feature k of the conv equals (input - k). It is followed by an
        activation layer of type delta with a delta bias of 0.
        """
        new_layers = []
        layers_to_remove = []

        for layer in list(self._graph):
            if layer.op == LayerType.one_hot:
                layers_to_remove.append(layer)

                # the conv takes over the inputs of the one hot layer
                conv = Conv2DLayer()
                conv.name = f"conv_1x1_{layer.name}"
                conv.index = self._graph.get_next_index()
                conv.kernel = np.ones((1, 1, 1, layer.num_classes))
                conv.kernel_shape = [1, 1, 1, layer.num_classes]
                conv.bias = np.array(np.arange(layer.num_classes)) * -1
                conv.dilations = [1, 1, 1, 1]
                conv.strides = [1, 1, 1, 1]
                conv.padding = PaddingType.valid
                conv.inputs = layer.inputs
                conv.input_indices = layer.input_indices
                conv.input_shapes = layer.input_shapes.copy()
                conv.output_shapes = [layer.output_shape.copy()]
                conv.move_params(layer)
                conv.block_info = (BlockType.ONE_HOT, layer.name)
                self._graph.add_node(conv)
                new_layers.append(conv)

                # the delta activation takes over the outputs of the one hot layer
                delta = ActivationLayer()
                delta.index = conv.index + 1
                delta.name = f"base_activation_{layer.name}"
                delta.inputs = [conv.name]
                delta.input_indices = [conv.index]
                delta.input_shapes = conv.output_shapes.copy()
                delta.output_shapes = layer.output_shapes.copy()
                delta.activation = ActivationType.delta
                delta.activation_delta_bias = 0
                delta.outputs = layer.outputs.copy()
                delta.output_indices = layer.output_indices.copy()
                delta.move_params(layer)
                delta.block_info = (BlockType.ONE_HOT, layer.name)
                self._graph.add_node(delta)
                self._graph.add_edge(conv, delta)
                new_layers.append(delta)

                # the only output of the conv is the delta activation
                conv.outputs = [delta.name]
                conv.output_indices = [delta.index]

                # rewire the predecessors of the one hot layer to the conv
                for pred in list(self._graph.predecessors(layer)):
                    self._graph.remove_edge(pred, layer)
                    self._graph.add_edge(pred, conv)
                    pred.replace_output_layer(layer.name, conv.name)
                    pred.replace_output_index(layer.index, conv.index)

                # rewire the successors of the one hot layer to the delta activation
                for succ in list(self._graph.successors(layer)):
                    self._graph.remove_edge(layer, succ)
                    self._graph.add_edge(delta, succ)
                    succ.replace_input_layer(layer.name, delta.name)
                    succ.replace_input_index(layer.index, delta.index)

                self._logger.debug(f"Splitted one hot layer at {layer.name}.")

        # finalize graph manipulation outside graph iteration
        for layer in layers_to_remove:
            self._graph.remove_layer(layer)

        for layer in new_layers:
            self._graph.relax_new_layer_into_graph(layer)

    def _split_scatter_nd_layers(self):
        r"""
        Replaces every scatter nd block with the form on the right:
                    +-------+    +-------+             +-------+  +---------+   +-------+  +---------+
                    | DATA1 |    | DATA2 |             | DATA1 |  | MASK 0S |   | DATA2 |  | MASK 1S |
                    +-----+-+    +-+-----+             +-------+  +---------+   +-------+  +---------+
                          |        |                          \       /             \       /
        +----------+      |--------+----+                      \     /               \     /
        | INDICES  |------+  SCATTER ND |       ==              \   /                 \   /
        +----------+      +-----+-------+                      +-------+            +-------+
                                |                              |EW MULT|            |EW MULT|
                                |                              +-------+            +-------+
                                                                        \          /
                                                                         \        /
                                                                        +----------+
                                                                        |  EW ADD  |
                                                                        +----------+

        The two masks are const inputs built from ``layer.indices``: "MASK 0S" is 0 at every scattered
        position and 1 elsewhere, and "MASK 1S" is its complement. The first input is multiplied by
        "MASK 0S", the second input by "MASK 1S", and the two products are summed.
        NOTE(review): this presumably relies on the second input (the updates) already being laid out like
        the data tensor - confirm against the scatter_nd translation.
        """

        layers_to_remove = []

        def create_const_inputs(indices_values):
            # The last axis of the indices holds the coordinates; the leading axes give the mask shape.
            indices_values_orig_shape = indices_values.shape
            data_shape = indices_values_orig_shape[:-1]

            mask = np.ones(data_shape)
            for idx in np.ndindex(data_shape):
                # zero the mask at every coordinate that scatter_nd writes to
                actual_indices = indices_values[idx]
                mask[tuple(actual_indices)] = 0
            # the 4-axis permutation below requires a rank 4 mask (i.e. rank 5 indices)
            mask = np.transpose(mask, (0, 2, 1, 3))
            # collapse to (1, 1, <axis 1 after the transpose>, <everything else>)
            mask = np.reshape(mask, (1, 1, mask.shape[1], -1))
            return mask, 1 - mask  # in case of updates indices, create the complement of the ones mask

        for layer in list(self._graph):
            if layer.op == LayerType.scatter_nd:
                data_const, update_const = create_const_inputs(layer.indices)
                # the five new layers get consecutive indices starting at the next free index
                base_index = self._graph.get_next_index()
                # const input holding the mask that zeroes the scattered positions of the data
                data_const_input = ConstInputLayer()
                data_const_input.name = f"{layer.name}_data_const"
                data_const_input.index = base_index
                data_const_input.original_names = layer.original_names.copy()
                data_const_input.input_shapes = [[-1, *data_const.shape[1:]]]
                data_const_input.output_shapes = data_const_input.input_shapes.copy()
                # drop the leading axis of size 1 created by the reshape
                data_const_input.const_values = data_const[0]
                base_index += 1

                # const input holding the complementary mask, applied to the updates
                updates_const_input = ConstInputLayer()
                updates_const_input.name = f"{layer.name}_updates_const"
                updates_const_input.index = base_index
                updates_const_input.original_names = layer.original_names.copy()
                updates_const_input.input_shapes = [[-1, *update_const.shape[1:]]]
                updates_const_input.output_shapes = updates_const_input.input_shapes.copy()
                updates_const_input.const_values = update_const[0]
                base_index += 1

                # first scatter_nd input * zeros mask
                data_ew_mult = EWMultLayer()
                data_ew_mult.index = base_index
                data_ew_mult.name = f"ew_mult_data_{layer.name}"
                data_ew_mult.original_names = layer.original_names.copy()
                data_ew_mult.inputs = [layer.inputs[0], data_const_input.name]
                data_ew_mult.input_shapes = [layer.input_shapes[0], data_const_input.input_shapes[0]]
                data_ew_mult.output_shapes = [layer.input_shapes[0]]
                base_index += 1

                # second scatter_nd input * ones mask
                updates_ew_mult = EWMultLayer()
                updates_ew_mult.index = base_index
                updates_ew_mult.name = f"ew_mult_updates_{layer.name}"
                updates_ew_mult.original_names = layer.original_names.copy()
                updates_ew_mult.inputs = [layer.inputs[1], updates_const_input.name]
                updates_ew_mult.input_shapes = [layer.input_shapes[1], updates_const_input.input_shapes[0]]
                updates_ew_mult.output_shapes = [layer.input_shapes[1]]
                base_index += 1

                # sum of the two masked tensors replaces the scatter_nd output
                scatter_nd_ew_add = EWAddLayer()
                scatter_nd_ew_add.index = base_index
                scatter_nd_ew_add.name = f"ew_add_{layer.name}"
                scatter_nd_ew_add.original_names = layer.original_names.copy()
                scatter_nd_ew_add.inputs = [data_ew_mult.name, updates_ew_mult.name]
                scatter_nd_ew_add.input_shapes = [data_ew_mult.output_shapes[0], updates_ew_mult.output_shapes[0]]
                scatter_nd_ew_add.output_shapes = [layer.output_shapes[0]]
                # NOTE(review): base_index is not read again after this increment
                base_index += 1

                # connects layers to graph
                preds = list(self._graph.predecessors(layer))  # two preds for scatter_nd
                self._fuser_helper.replace_succ(preds[0], layer, data_ew_mult)
                self._fuser_helper.add_succs(data_const_input, [data_ew_mult], update_output_shapes=False)
                self._fuser_helper.replace_succ(preds[1], layer, updates_ew_mult)
                self._fuser_helper.add_succs(updates_const_input, [updates_ew_mult], update_output_shapes=False)
                self._fuser_helper.add_preds(data_ew_mult, [preds[0], data_const_input], update_input_shapes=False)
                self._fuser_helper.add_preds(
                    updates_ew_mult, [preds[1], updates_const_input], update_input_shapes=False
                )
                self._fuser_helper.add_succs(data_ew_mult, [scatter_nd_ew_add], update_output_shapes=False)
                self._fuser_helper.add_succs(updates_ew_mult, [scatter_nd_ew_add], update_output_shapes=False)
                self._fuser_helper.add_preds(
                    scatter_nd_ew_add, [data_ew_mult, updates_ew_mult], update_input_shapes=False
                )

                # every consumer of the scatter_nd layer now reads from the ew add
                succs = list(self._graph.successors(layer))
                for succ in succs:
                    self._fuser_helper.replace_pred(succ, layer, scatter_nd_ew_add)
                    self._fuser_helper.add_succs(scatter_nd_ew_add, [succ], update_output_shapes=False)

                layers_to_remove.append(layer)
                self._logger.debug(f"Splitted scatter_nd layer at {layer.name}.")

        # the original layers are removed only after the iteration over the graph is done
        for layer in layers_to_remove:
            self._graph.remove_layer(layer)

    def _handle_einsum_layers(self):
        """
        Replace every einsum layer with an equivalent matmul or conv layer, according to its equation.

        * "byhwc,hkc->byhwk" / "byhwc,wkc->byhwk": a const input holding the transposed weights feeding a
          matmul layer (the second equation also gets height<->width transposes around the matmul).
        * "bmchw,bnmc->bmhwn", "bchw,cj->bjhw", "nkctv,kvw->nctw": a conv layer whose kernel is built from
          the einsum weights.

        Raises:
            BackendFuserException: if the equation of an einsum layer is none of the above.
        """
        layers_to_remove = []
        for layer in list(self._graph):
            if layer.op == LayerType.einsum:
                if layer.equation in ["byhwc,hkc->byhwk", "byhwc,wkc->byhwk"]:
                    """
                    replace the einsum with matmul operation:
                    for equation "byhwc,hkc->byhwk":
                        1X56X56X96            1X56X14X96
                    +----------------+   +----------------+
                    |     INPUT      |   |    WEIGHTS     |
                    +-------+--------+   +--------+-------+
                            |                     |
                            |            +--------+-------+
                            |            |   TRANSPOSE    |
                            |            +--------+-------+
                            |                     |
                            |                 1X56X96X14
                            |                     |
                    +-------+---------------------+------+
                    |      MATMUL (WINDOWS=[56, 1, 1])   |
                    +-----------------+------------------+
                                      |
                                      |
                                  1X56X56X14
                    for equation "byhwc,wkc->byhwk", adds transpose width->height after input and after matmul,
                    and changes the windows to be on height.
                    """
                    # swap the last two weight axes, matching the TRANSPOSE step in the diagram above
                    weights = layer.weights.transpose((0, 2, 1))
                    base_index = self._graph.get_next_index()
                    # the transposed weights become a const input of the matmul
                    const_input = ConstInputLayer()
                    const_input.name = f"{layer.name}_const_input"
                    const_input.index = base_index
                    const_input.original_names = layer.original_names.copy()
                    const_input.input_shapes = [[-1, *weights.shape]]
                    const_input.output_shapes = const_input.input_shapes.copy()
                    const_input.const_values = weights
                    base_index += 1
                    # NOTE(review): MatmulLayer is not among the imports visible at the top of this file -
                    # confirm it is imported elsewhere.
                    matmul = MatmulLayer()
                    matmul.name = f"{layer.name}_matmul"
                    matmul.index = base_index
                    matmul.original_names = layer.original_names.copy()
                    matmul.input_shapes = [layer.input_shape, const_input.input_shape]
                    # NOTE(review): this is a flat shape rather than a list of shapes like the other
                    # output_shapes assignments in this file - confirm the setter accepts it.
                    matmul.output_shapes = [*layer.input_shapes[0][:-1], weights.shape[-1]]
                    # one window per entry of input axis 1
                    matmul.input_windows = [layer.input_shape[1], 1, 1]
                    # ratio between the input features and the last axis of the original weights
                    matmul.groups = layer.input_shape[-1] // layer.weights.shape[-1]
                    matmul.input_tiles = [[1, 1, 1], [1, 1, layer.input_shape[-1] // layer.weights.shape[-1]]]
                    matmul.transpose_matmul_input = False

                    # connects the new layers
                    preds = list(self._graph.predecessors(layer))
                    succs = list(self._graph.successors(layer))

                    self._fuser_helper.add_succs(const_input, [matmul], update_output_shapes=False)
                    self._fuser_helper.add_preds(matmul, [preds[0], const_input], update_input_shapes=False)
                    self._fuser_helper.replace_succ(preds[0], layer, matmul)
                    self._fuser_helper.add_succs(matmul, succs, update_output_shapes=False)
                    # NOTE(review): only the first successor is re-pointed to the matmul here, although all
                    # successors were added as matmul outputs above - confirm a single successor is guaranteed.
                    self._fuser_helper.replace_pred(succs[0], layer, matmul)

                    if layer.equation == "byhwc,wkc->byhwk":
                        # adds transpose layers after input and after matmul
                        self.add_format_conversion_successor(
                            preds[0],
                            matmul,
                            FormatConversionType.transpose_height_width,
                        )
                        self.add_format_conversion_successor(
                            matmul,
                            succs[0],
                            FormatConversionType.transpose_height_width,
                        )
                        # windows follow input axis 2 instead of axis 1 for this equation
                        matmul.input_windows = [layer.input_shape[2], 1, 1]
                elif layer.equation in ["bmchw,bnmc->bmhwn", "bchw,cj->bjhw", "nkctv,kvw->nctw"]:
                    # conv equivalent
                    conv = Conv2DLayer()
                    conv.name = f"{layer.name}_conv"
                    conv.index = self._graph.get_next_index()
                    conv.original_names = layer.original_names.copy()
                    conv.input_shapes = layer.input_shapes.copy()
                    conv.output_shapes = layer.output_shapes.copy()
                    conv.padding = PaddingType.valid
                    conv.strides = [1, 1, 1, 1]
                    conv.dilations = [1, 1, 1, 1]
                    # NOTE(review): the two assignments below repeat the ones made a few lines above
                    conv.input_shapes = layer.input_shapes.copy()
                    conv.output_shapes = layer.output_shapes.copy()
                    kernel = layer.weights

                    if layer.equation == "bmchw,bnmc->bmhwn":
                        # acts like group
                        # one group per entry of weights axis 2
                        groups = kernel.shape[2]
                        kernel = np.transpose(kernel, [0, 3, 2, 1])
                        # kernel becomes (1, 1, <axis 1 after transpose>, <last axis after transpose> * groups)
                        kernel = kernel.reshape(1, 1, kernel.shape[1], kernel.shape[-1] * groups)

                    elif layer.equation in ["bchw,cj->bjhw", "nkctv,kvw->nctw"]:
                        # acts like conv1x1
                        groups = 1
                        # rank 2 weights map directly to (in, out); otherwise the first two axes are merged
                        kernel = (
                            kernel.reshape(1, 1, kernel.shape[0], kernel.shape[1])
                            if len(kernel.shape) == 2
                            else kernel.reshape(1, 1, kernel.shape[0] * kernel.shape[1], kernel.shape[2])
                        )

                        if layer.equation == "nkctv,kvw->nctw":
                            # gcn equivalent
                            conv.transpose_output_width_features = True

                    conv.kernel = kernel
                    conv.kernel_shape = kernel.shape
                    conv.groups = groups

                    # connects the new layer
                    preds = list(self._graph.predecessors(layer))
                    succs = list(self._graph.successors(layer))

                    self._fuser_helper.add_succs(conv, succs, update_output_shapes=False)
                    self._fuser_helper.add_preds(conv, preds, update_input_shapes=False)

                    for succ in succs:
                        self._fuser_helper.replace_pred(succ, layer, conv)
                    for pred in preds:
                        self._fuser_helper.replace_succ(pred, layer, conv)
                else:
                    raise BackendFuserException(f"Unsupported einsum equation {layer.equation}")
                layers_to_remove.append(layer)
        # the replaced einsum layers are removed only after the iteration over the graph is done
        for layer in layers_to_remove:
            self._graph.remove_layer(layer)

    def _handle_layer_norm(self):
        """
        Separate layer norm beta and gamma to normalization layers and split layer norm layer in case of spatial reduce.
        """
        new_layers = []
        layers_to_remove = []
        successors_meta_data = {}

        for layer in list(self._graph):
            if layer.op == LayerType.layer_normalization:
                self._split_layer_norm_beta_gamma(new_layers, successors_meta_data, layer)
                if not (layer.axes == [3] or layer.axes == [1, 2, 3]):
                    LayerNormMapping(self._graph, {}, {}, None).split_single_layer_norm(
                        new_layers,
                        layers_to_remove,
                        layer,
                    )

        for layer in layers_to_remove:
            self._graph.remove_layer(layer)

        for layer in new_layers:
            self._graph.relax_new_layer_into_graph(layer, successors_meta_data)

    def _split_layer_norm_beta_gamma(self, new_layers, successors_meta_data, layer):
        index = self._graph.get_next_index()
        beta = layer.beta.flatten()
        if len(beta) > layer.output_features:
            scalar = beta[0]
            if np.any(beta != scalar):
                raise BackendFuserException(
                    f"In {layer.name} given beta argument is over all dimensions which is not supported yet",
                )

            beta = np.array([scalar])

        gamma = layer.gamma.flatten()
        if len(gamma) > layer.output_features:
            scalar = gamma[0]
            if np.any(gamma != scalar):
                raise BackendFuserException(
                    f"In {layer.name} given gamma argument is over all dimensions which is not supported yet",
                )

            gamma = np.array([scalar])

        if beta[0] == 0 and gamma[0] == 1:
            return

        norm_gamma_beta = NormalizationLayer()
        norm_gamma_beta.mean = np.negative(beta) / gamma
        norm_gamma_beta.std = 1 / gamma
        norm_gamma_beta.index = index
        norm_gamma_beta.name = f"{layer.name}_beta_gamma"
        norm_gamma_beta.move_params(layer)
        self._graph.add_node(norm_gamma_beta)
        new_layers.append(norm_gamma_beta)

        norm_gamma_beta.append_input_layer(layer.name)
        norm_gamma_beta.append_input_index(layer.index)
        norm_gamma_beta.input_shapes = [layer.input_shape]
        norm_gamma_beta.output_shapes = layer.output_shapes.copy()

        layer.outputs = [norm_gamma_beta.name]
        layer.output_indices = [norm_gamma_beta.index]
        layer.beta = 0
        layer.gamma = 1
        for succ in list(self._graph.successors(layer)):
            self._graph.remove_edge(layer, succ)
            self._graph.add_edge(norm_gamma_beta, succ)
            succ.replace_input_index(layer.index, norm_gamma_beta.index)
            succ.replace_input_layer(layer.name, norm_gamma_beta.name)
            HailoNN.update_successors_meta_data(succ, successors_meta_data)
            norm_gamma_beta.append_output_layer(succ.name)
            norm_gamma_beta.append_output_index(succ.index)
        self._graph.add_edge(layer, norm_gamma_beta)

    def _handle_null_resizes(self):
        """Find resizes with ratios == 1 and remove them."""
        layers_to_remove = []

        for layer in list(self._graph):
            if layer.op == LayerType.resize:
                resize_layer = layer
                if (
                    len(resize_layer.h_ratios) == len(resize_layer.w_ratios) == len(resize_layer.f_ratios) == 1
                    and resize_layer.h_ratios[0] == resize_layer.w_ratios[0] == resize_layer.f_ratios[0] == 1.0
                ):
                    # Found null resize layer (ratio is 1.0 for all dims)
                    # remove the null resize layer and update shapes and connections
                    self._fuser_helper.remove_layer(resize_layer, layers_to_remove)
                    self._logger.debug(f"Removed null resize layer {resize_layer.name}.")

        for resize_layer in layers_to_remove:
            self._graph.remove_layer(resize_layer)

    def _handle_null_normalizations(self):
        """Find normalizations with mean = 0 and std = 1 and remove them."""
        layers_to_remove = []

        for layer in list(self._graph):
            if layer.op == LayerType.normalization:
                norm_layer = layer
                if (
                    all(m == 0 for m in norm_layer.mean)
                    and all(s == 1 for s in norm_layer.std)
                    and layer.activation == ActivationType.linear
                ):
                    # remove the null norm layer and update shapes and connections
                    self._fuser_helper.remove_layer(norm_layer, layers_to_remove)
                    self._logger.debug(f"Removed null norm layer {norm_layer.name}.")

        for norm_layer in layers_to_remove:
            self._graph.remove_layer(norm_layer)

    def _handle_null_ew_const_input(self):
        """
        Remove element-wise layers whose single const input is the neutral value of the operation:
        0 for ew_add and ew_sub, 1 for ew_mult and ew_div.

        Layers with transposed or spatially flattened output are left untouched. The const input is removed
        as well when the element-wise layer is its only successor.
        NOTE(review): the operand position of the const input is not checked here, so a neutral const on the
        left side of a sub/div would be removed too - confirm this cannot occur.
        """
        neutral_value_by_op = {
            LayerType.base_ew_add: 0,
            LayerType.base_ew_sub: 0,
            LayerType.ew_mult: 1,
            LayerType.ew_div: 1,
        }
        obsolete = []

        for node in list(self._graph):
            if node.op not in neutral_value_by_op:
                continue
            if node.transpose_output_width_features or node.spatial_flatten_output:
                continue

            const_inputs = [src for src in self._graph.predecessors(node) if src.op == LayerType.const_input]
            # zero const inputs is a regular ew op, more than one is the add_n case
            if len(const_inputs) != 1:
                continue

            const_node = const_inputs[0]
            if not np.all(const_node.const_values == neutral_value_by_op[node.op]):
                continue

            # the const input goes away only when nothing else consumes it
            consumers = list(self._graph.successors(const_node))
            if len(consumers) == 1:
                obsolete.append(const_node)

            self._graph.remove_edge(const_node, node)
            self._fuser_helper.remove_layer(node, obsolete)
            self._logger.debug(f"Removed {node.op} {node.name} and const_input {const_node.name}.")

        for node in obsolete:
            self._graph.remove_layer(node)

    def _handle_null_format_conversion(self):
        """
        Find format conversion of the types - spatial flatten, spatial unflatten and flat to frames layer and
        remove them in case of:
        1.They don't change the tensor order or shape
        2.Before global pooling layer
        """
        layers_to_remove = []

        for layer in list(self._graph):
            if layer.op != LayerType.format_conversion:
                continue

            if layer.conversion_type not in [FormatConversionType.flat_to_frames, FormatConversionType.spatial_reshape]:
                continue

            succs = list(self._graph.successors(layer))
            if (len(layer.input_shape) == 4 and layer.input_shape == layer.output_shape) or (
                all(
                    isinstance(succ, PoolingLayer) and (succ.is_global_max_pool() or succ.is_global_avg_pool())
                    for succ in succs
                )
            ):
                for succ in succs:
                    if isinstance(succ, PoolingLayer) and (succ.is_global_max_pool() or succ.is_global_avg_pool()):
                        succ.set_input_shapes(layer.input_shapes)
                self._fuser_helper.remove_layer(layer, layers_to_remove)
                self._logger.debug(f"Removed null {layer.full_name_msg}.")

        for layer in layers_to_remove:
            self._graph.remove_layer(layer)

    def _handle_conv16x16s16(self):
        """
        Convert conv 16x16s16 to s2d + conv1x1.

        Applies to base conv layers whose kernel height/width and stride height/width are all 16, whose
        input height and width are both multiples of 16, and which have at most one predecessor. A serial
        space to depth layer with 16x16 blocks is inserted between the predecessor and the conv, and the
        conv itself becomes a 1x1, stride 1 conv over the space to depth output.
        """
        new_layers = []
        successors_meta_data = {}

        for layer in list(self._graph):
            if (
                layer.op == LayerType.base_conv
                and layer.kernel_width == layer.kernel_height == layer.stride_height == layer.stride_width == 16
                and layer.input_height % 16 == layer.input_width % 16 == 0
            ):
                preds = list(self._graph.predecessors(layer))

                # convs with more than one predecessor are left as they are
                if len(preds) > 1:
                    continue

                self._logger.debug(
                    f"Found Conv16x16s16 {layer.name} and replaced it with Space to Depth and Conv1x1 layer",
                )

                pred = preds[0]
                s2d = SpaceToDepthLayer()
                s2d.index = self._graph.get_next_index()
                s2d.name = f"s2d_{layer.name}"
                self._graph.add_node(s2d)
                s2d.append_input_index(pred.index)
                s2d.append_output_index(layer.index)
                s2d.input_shape = layer.input_shape
                # every 16x16 spatial block is folded into the features: H/16 x W/16 x (F * 256)
                s2d.output_shapes = [
                    [-1, layer.input_height // 16, layer.input_width // 16, layer.input_features * 256],
                ]
                s2d.block_sizes = [16, 16]
                s2d.space_to_depth_type = SpaceToDepthType.serial
                s2d.move_params(layer)
                # must be set before kernel_shape below, which reads layer.input_features
                layer.input_shape = s2d.output_shape
                # pred -> layer becomes pred -> s2d -> layer
                self._graph.remove_edge(pred, layer)
                self._graph.add_edge(pred, s2d)
                self._graph.add_edge(s2d, layer)
                layer.replace_input_layer(pred.name, s2d.name)
                layer.replace_input_index(pred.index, s2d.index)
                layer.kernel_shape = [1, 1, layer.input_features, layer.output_features]
                layer.strides = [1, 1, 1, 1]
                pred.replace_output_layer(layer.name, s2d.name)
                pred.replace_output_index(layer.index, s2d.index)
                # presumably rearranges the 16x16 kernel weights to fit the new 1x1 kernel shape - TODO confirm
                layer.reshape_kernel_conv_octxoct()
                new_layers.append(s2d)

                HailoNN.update_successors_meta_data(layer, successors_meta_data)
        # new layers are relaxed into the graph only after the iteration over the graph is done
        for layer in new_layers:
            self._graph.relax_new_layer_into_graph(layer, successors_meta_data)

    def _handle_flat_to_frames_before_resizes(self):
        """
        find resize layers with flat to frames reshape predecessor, and remove it to have rank 2 input
        """
        layers_to_remove = []

        for layer in list(self._output_graph):
            if (
                layer.op == LayerType.resize
                and layer.resize_method == ResizeMethod.nearest_neighbor
                and len(layer.input_shape) > 2
                and layer.input_height == layer.input_width == 1
            ):
                resize_pred = next(iter(self._output_graph.predecessors(layer)))
                if (
                    resize_pred.op == LayerType.format_conversion
                    and resize_pred.conversion_type == FormatConversionType.flat_to_frames
                ):
                    self._fuser_helper.remove_layer(resize_pred, layers_to_remove)
                    self._logger.debug(
                        f"Removed flat to frames reshape layer {resize_pred.name} that is followed by a "
                        f"resize layer {layer.name}",
                    )

        for layer in layers_to_remove:
            self._output_graph.remove_layer(layer)

    def _update_resize_layers_methods(self):
        for layer in list(self._graph):
            if layer.op != LayerType.resize:
                continue

            if layer.resize_method != ResizeMethod.bilinear:
                layer.set_compilation_params(resize_bilinear_streaming=ResizeBilinearStreamingPolicy.disabled)
            elif layer.resize_method == ResizeMethod.bilinear and layer.input_height == layer.input_width == 1:
                # For resize bilinear of the form (1x1)->(MxN) always prefer nearest neighbor method
                layer.resize_method = ResizeMethod.nearest_neighbor
                layer.set_compilation_params(resize_bilinear_streaming=ResizeBilinearStreamingPolicy.disabled)

    def _split_resize_layers(self):
        for layer in list(self._graph):
            # This function deals with resize-layers only
            if layer.op != LayerType.resize:
                continue
            # 1. All resize-layers are initialized with hw_type LCU.
            # 2. Nearest or Bilinear (half pixels, disabled) should remain LCU.
            # 3. Bilinear (align_corners) should change to PPU, and split ratios when needed.
            if (
                layer.op != LayerType.resize
                or layer.resize_method == ResizeMethod.nearest_neighbor
                or not layer.is_bilinear_align_corners
            ):
                layer.set_compilation_params(hw_layer_type_list=["lcu"])
                continue

            old_h_ratio = layer.h_ratios
            old_w_ratio = layer.w_ratios

            if not len(old_h_ratio) == len(old_w_ratio) == 1:
                raise BackendFuserException(
                    f'Failed to split resize layer {layer.name} due to bad input ratios. '
                    f'(translated from [{", ".join(layer.original_names)}]).',
                )

            split_h_ratio = self._calculate_ratios(old_h_ratio[0])
            split_w_ratio = self._calculate_ratios(old_w_ratio[0])

            hw_type = "ppu"
            split_size = max(len(split_h_ratio), len(split_w_ratio))
            split_h_ratio += [self.RESIZE_RATIO_NEUTRAL_VALUE] * (split_size - len(split_h_ratio))
            split_w_ratio += [self.RESIZE_RATIO_NEUTRAL_VALUE] * (split_size - len(split_w_ratio))

            # No need to check for streaming bilinear heuristic if method is nearest neighbor,
            # or in/out height/width are 1, since it is not supported with streaming bilinear
            if layer.input_width == 1 or layer.input_height == 1 or layer.output_width == 1 or layer.output_height == 1:
                layer.h_ratios = split_h_ratio
                layer.w_ratios = split_w_ratio
                layer.set_compilation_params(hw_layer_type_list=[hw_type] * split_size)
            else:
                split_streaming_bilinear_h_ratio = self._calculate_streaming_bilinear_ratio(
                    layer.input_height,
                    old_h_ratio[0],
                )
                split_streaming_bilinear_w_ratio = self._calculate_streaming_bilinear_ratio(
                    layer.input_width,
                    old_w_ratio[0],
                )
                streaming_bilinear_split_size = max(
                    len(split_streaming_bilinear_h_ratio),
                    len(split_streaming_bilinear_w_ratio),
                )
                split_streaming_bilinear_h_ratio += [self.RESIZE_RATIO_NEUTRAL_VALUE] * (
                    streaming_bilinear_split_size - len(split_streaming_bilinear_h_ratio)
                )
                split_streaming_bilinear_w_ratio += [self.RESIZE_RATIO_NEUTRAL_VALUE] * (
                    streaming_bilinear_split_size - len(split_streaming_bilinear_w_ratio)
                )

                bilinear_buffers_consumption = 0
                streaming_bilinear_buffers_consumption = 0
                current_input_height = layer.input_height

                for i in range(split_size):
                    # 2 frames in input + 2 rows in output + 1 row in reshape
                    bilinear_buffers_consumption += 2 * current_input_height + 2 + 1
                    current_input_height *= split_h_ratio[i]

                for i in range(streaming_bilinear_split_size):
                    # 2 rows in input + 2 * (ceil(h_ratio) + 1) in output + 1 row in line buffer + 1 row in reshape
                    streaming_bilinear_buffers_consumption += (
                        2 + 2 * ceil(split_streaming_bilinear_h_ratio[i] + 1.0) + 1 + 1
                    )

                if (
                    split_size >= streaming_bilinear_split_size
                    and streaming_bilinear_buffers_consumption < bilinear_buffers_consumption
                ):
                    layer.h_ratios = split_streaming_bilinear_h_ratio
                    layer.w_ratios = split_streaming_bilinear_w_ratio
                    layer.set_compilation_params(hw_layer_type_list=[hw_type] * streaming_bilinear_split_size)
                else:
                    layer.h_ratios = split_h_ratio
                    layer.w_ratios = split_w_ratio
                    layer.set_compilation_params(hw_layer_type_list=[hw_type] * split_size)

    def _calculate_ratios(self, goal_ratio):
        """
        Break a resize ratio into per-layer multipliers whose product equals goal_ratio.

        The docstring of the original implementation states that the PPU residual width used
        for bilinear resize is MAXIMUM_RESIZE_RATIO_PER_LAYER (==2**4), so a single layer must
        not upscale by more than that. Larger ratios are expressed as a chain of multipliers,
        e.g. goal_ratio=32 => [16, 2].

        The ratio is returned untouched, as a single-element list, when no split is needed:
        * goal_ratio < RESIZE_RATIO_NEUTRAL_VALUE (down-scaling)
        * goal_ratio == RESIZE_RATIO_NEUTRAL_VALUE (size is kept)
        Ratios up to MAXIMUM_RESIZE_RATIO_PER_LAYER also yield a single multiplier.

        Args:
            goal_ratio: the total resize ratio requested along one axis.

        Returns:
            list of multipliers, each at most MAXIMUM_RESIZE_RATIO_PER_LAYER.
        """
        neutral = self.RESIZE_RATIO_NEUTRAL_VALUE
        if goal_ratio <= neutral:
            return [goal_ratio]

        per_layer_cap = self.MAXIMUM_RESIZE_RATIO_PER_LAYER
        multipliers = []
        accumulated = neutral
        while accumulated < goal_ratio:
            # take the largest allowed step, or whatever is left to reach the goal
            step = min(per_layer_cap, goal_ratio / accumulated)
            accumulated *= step
            multipliers.append(step)

        return multipliers

    def _calculate_streaming_bilinear_ratio(self, input_size, goal_ratio):
        """
        Break a resize ratio into per-layer multipliers that are valid for streaming bilinear.

        In streaming bilinear, the following condition must hold:
        output_size < MAXIMUM_RESIZE_RATIO_PER_LAYER * input_size - (MAXIMUM_RESIZE_RATIO_PER_LAYER - 1)
        so the largest multiplier allowed for each step depends on the size reached so far
        (input_size * product of the previous multipliers).

        Args:
            input_size: size of the resized axis at the input of the first resize step.
            goal_ratio: the total resize ratio requested along that axis.

        Returns:
            list of multipliers whose product reaches goal_ratio. A goal_ratio that is not
            greater than RESIZE_RATIO_NEUTRAL_VALUE is returned as a single-element list.

        Raises:
            ValueError: if the current size is too small for any up-scaling step (the allowed
                multiplier does not exceed RESIZE_RATIO_NEUTRAL_VALUE). Without this check the
                loop below would never make progress and would not terminate.
        """
        if goal_ratio <= self.RESIZE_RATIO_NEUTRAL_VALUE:
            return [goal_ratio]

        split_ratio = []
        cur_value = self.RESIZE_RATIO_NEUTRAL_VALUE
        while cur_value < goal_ratio:
            max_ratio = self.MAXIMUM_RESIZE_RATIO_PER_LAYER - int(
                ceil((self.MAXIMUM_RESIZE_RATIO_PER_LAYER - 1) / (input_size * cur_value)),
            )
            if max_ratio <= self.RESIZE_RATIO_NEUTRAL_VALUE:
                # a multiplier of 1 (or less) never advances cur_value towards goal_ratio
                raise ValueError(
                    f"Streaming bilinear resize cannot upscale an input of size {input_size * cur_value} "
                    f"(requested ratio {goal_ratio})",
                )
            multiplier = min(max_ratio, goal_ratio / cur_value)
            cur_value *= multiplier
            split_ratio.append(multiplier)

        return split_ratio

    def _handle_inv_sqrt_activation(self):
        """
        Collapse a sqrt activation followed by a single inv_pos base activation into inv_sqrt.

        Two cases are handled for a layer (base_activation / normalization / reduce_sum) whose
        activation is sqrt and whose only successor is an inv_pos base activation:
        * the layer itself is a base activation: it becomes inv_sqrt, takes over the successors
          of the inv_pos layer, and the inv_pos layer is removed from the graph.
        * otherwise: the layer's activation becomes linear and the inv_pos successor becomes
          inv_sqrt; no layer is removed.
        """
        fusable_ops = [LayerType.base_activation, LayerType.normalization, LayerType.reduce_sum]
        obsolete_layers = []
        for layer in list(self._graph):
            if layer.op not in fusable_ops or layer.activation != ActivationType.sqrt:
                continue

            successors = list(self._graph.successors(layer))
            if len(successors) != 1:
                continue

            inv_pos = successors[0]
            if inv_pos.op != LayerType.base_activation or inv_pos.activation != ActivationType.inv_pos:
                continue

            inv_pos_successors = list(self._graph.successors(inv_pos))

            if layer.op != LayerType.base_activation:
                # the sqrt is fused into a non-activation layer: move the whole inv_sqrt to the successor
                layer.activation = ActivationType.linear
                inv_pos.activation = ActivationType.inv_sqrt
                self._logger.debug(
                    f"Converted {layer.op.value} layer with sqrt activation and inv_pos "
                    f"successor to {layer.op.value} layer and inv_sqrt successor at "
                    f"{layer.name}",
                )
                continue

            for position, consumer in enumerate(inv_pos_successors):
                if position == 0:
                    # the first consumer takes the output slot previously held by inv_pos
                    layer.replace_output_index(inv_pos.index, consumer.index)
                    layer.replace_output_layer(inv_pos.name, consumer.name)
                else:
                    # every further consumer is appended as an additional output
                    layer.append_output_index(consumer.index)
                    layer.append_output_layer(consumer.name)
                    layer.append_output_shape(consumer.input_shape)

                consumer.replace_input_index(inv_pos.index, layer.index)
                consumer.replace_input_layer(inv_pos.name, layer.name)
                consumer.replace_input_shape(inv_pos.name, layer.input_shape)
                self._graph.add_edge(layer, consumer)

            self._graph.remove_edge(layer, inv_pos)
            obsolete_layers.append(inv_pos)
            layer.activation = ActivationType.inv_sqrt
            layer.move_params(inv_pos)
            self._logger.debug(f"Fused sqrt and inv pos to inv sqrt, removed {inv_pos.name}")

        # removal is deferred so the graph is not shrunk while fusions are still being applied
        for obsolete_layer in obsolete_layers:
            self._graph.remove_layer(obsolete_layer)

    def _handle_non_positive_range_matmul_transposed_input(self):
        """
        Swap the inputs of matmul layers and transpose (width <-> features) around them.

        A matmul layer is rewritten only when all of the following hold:
        1. input0 does not have a guaranteed positive range
        2. transpose_matmul_input is False
        3. input1 has a guaranteed positive range

        An input counts as "guaranteed positive" when the producing layer is a softmax or carries
        one of the activations listed below. For a matching matmul the predecessors order is
        switched, a transpose_width_features format conversion is inserted after each input and
        after the matmul itself (towards its first successor), and the matmul's windows, kernel
        shape, output shape and groups are updated accordingly.
        """
        positive_range_activations = [
            ActivationType.exp,
            ActivationType.hardsigmoid,
            ActivationType.hardswish,
            ActivationType.relu,
            ActivationType.relu1,
            ActivationType.relu6,
            ActivationType.relu_positive_square,
            ActivationType.sigmoid,
            ActivationType.swish,
        ]
        positive_range_ops = [LayerType.softmax]

        def has_positive_range(layer_name):
            producer = self.graph.get_layer_by_name(layer_name)
            if any(producer.op == op for op in positive_range_ops):
                return True
            return hasattr(producer, "activation") and any(
                producer.activation == act for act in positive_range_activations
            )

        transpose_wf = FormatConversionType.transpose_width_features
        for layer in list(self.graph):
            if layer.op != LayerType.matmul:
                continue
            if has_positive_range(layer.inputs[0]):
                continue
            if layer.transpose_matmul_input:
                continue
            if not has_positive_range(layer.inputs[1]):
                continue

            # switch inputs order and add a transpose layer after each input
            self._fuser_helper.modify_preds_order(layer)
            for input_position in (1, 0):
                producer = self.graph.get_layer_by_name(layer.inputs[input_position])
                self.add_format_conversion_successor(producer, layer, transpose_wf)

            # groups is read for the windows before it is reset to 1
            layer.input_windows = [1, layer.groups, 1]
            layer.kernel_shape = [1, 1, layer.kernel_shape[2], layer.kernel_shape[3]]
            first_shape = layer.input_shapes[0]
            second_shape = layer.input_shapes[1]
            layer.output_shapes = [*first_shape[:2], first_shape[2], second_shape[3]]
            layer.groups = 1

            # transpose the matmul result on the edge towards its first successor
            first_successor = next(iter(self.graph.successors(layer)))
            self.add_format_conversion_successor(layer, first_successor, transpose_wf)

    def add_format_conversion_successor(self, layer, succ, conversion):
        """
        Insert a format conversion layer on the edge between ``layer`` and ``succ``.

        Args:
            layer: the producing layer; must list ``succ`` among its outputs.
            succ: the consuming layer; must list ``layer`` among its inputs.
            conversion: FormatConversionType.transpose_width_features or
                FormatConversionType.transpose_height_width.

        Raises:
            ValueError: if ``conversion`` is not one of the two supported transposes.
        """
        format_conversion_layer = FormatConversionLayer.create(layer.name, [layer.name], conversion)
        format_conversion_layer.name = f"{layer.name}_format_conversion"
        format_conversion_layer.index = self.graph.get_next_index()
        format_conversion_layer.original_names = layer.original_names.copy()

        # use the output shape of the slot that feeds succ (and not blindly the first one), so the
        # input and output shapes of the new layer agree also for multi-output layers such as a
        # feature splitter whose outputs differ in shape
        pred_output_shape = layer.output_shapes[layer.outputs.index(succ.name)].copy()

        if conversion == FormatConversionType.transpose_width_features:
            format_conversion_layer.output_shapes = [
                [-1, pred_output_shape[1], pred_output_shape[3], pred_output_shape[2]],
            ]
        elif conversion == FormatConversionType.transpose_height_width:
            format_conversion_layer.output_shapes = [
                [-1, pred_output_shape[2], pred_output_shape[1], pred_output_shape[3]],
            ]
        else:
            # NOTE(review): the message reads self.name - confirm this class defines it
            raise ValueError(f"Unsupported conversion type for {self.name} post fuser algorithm: {conversion}")

        input_index = succ.inputs.index(layer.name)
        succ.input_shapes[input_index] = format_conversion_layer.output_shapes[0].copy()

        # inserting input_shape manually to avoid ambiguous layer output shapes as in feature splitter case
        format_conversion_layer.input_shapes = [pred_output_shape.copy()]
        self._fuser_helper.add_preds(format_conversion_layer, [layer], update_input_shapes=False)
        self._fuser_helper.replace_pred(succ, layer, format_conversion_layer)
        self._fuser_helper.replace_succ(layer, succ, format_conversion_layer)
        self._fuser_helper.add_succs(format_conversion_layer, [succ])

    def _handle_format_conversions(self):
        """Run the format conversion handlers one after the other, in a fixed order."""
        handlers = (
            self._handle_general_reshape,
            self._handle_split_windowed_attention,
            self._handle_merge_windowed_attention,
            self._handle_groups_to_spatial_flatten,
            self._handle_spatial_flatten_to_groups,
        )
        for handler in handlers:
            handler()

    def _handle_general_reshape(self):
        """
        Replaces general_reshape format conversion layers with an equivalent block of layers.

        Only a single-input reshape of the form [b, h, w, c] -> [b, h * w, x, c // x] is handled;
        it is expanded to feature splitter -> x spatial flattens -> x transposes -> concat.
        The original format conversion layer is kept and reused as the first spatial flatten.

        Raises:
            BackendFuserException: if a single-input general reshape does not have that form.
        """
        successors_meta_data = {}
        new_layers = []
        for layer in list(self._graph):
            if (
                layer.op == LayerType.format_conversion
                and layer.conversion_type == FormatConversionType.general_reshape
            ):
                input_shapes = layer.input_shapes
                output_shapes = layer.output_shapes

                # layers with more than one input shape are left untouched (no else branch)
                if len(input_shapes) == 1:
                    output_shape = output_shapes[0]
                    h, w, c = input_shapes[0][1:]
                    out_h, out_w, out_c = output_shape[1:]
                    if c / out_c == out_w and out_h == w * h:
                        # the general reshape is of the form [b, h, w, c] -> [b, h * w, x, c // x]
                        # convert the general reshape to:
                        # 1. feature splitter -> [[b, h, w, c // x] * x]
                        # 2. spatial flatten  -> [[b, 1, h * w, c // x] * x]
                        # 3. transpose -> [[b, h*w, 1, c // x] * x]
                        # 4. concat -> [b, h*w, x, c // x]
                        base_index = self._graph.get_next_index()

                        # creates the feature splitter layer
                        num_of_splits = c // out_c
                        feature_splitter = FeatureSplitterLayer()
                        feature_splitter.index = base_index
                        feature_splitter.name = f"feature_splitter_{feature_splitter.index}"
                        feature_splitter.original_names = layer.original_names.copy()
                        feature_splitter.split_sizes = [out_c] * num_of_splits
                        new_layers.append(feature_splitter)
                        self._graph.push_layer(feature_splitter, list(self._graph.predecessors(layer)))

                        # the original layer is the first output of the splitter; one shape per split
                        feature_splitter.outputs = [layer.name]
                        feature_splitter.output_shapes = [[output_shape[0], h, w, out_c] for _ in range(num_of_splits)]
                        feature_splitter.output_indices = [layer.index]

                        base_index += 1

                        # replaces the general reshape format conversion to spatial flatten
                        original_outputs = layer.outputs.copy()
                        layer.conversion_type = FormatConversionType.spatial_reshape
                        layer.spatial_reshape_sizes = [1, h * w]
                        layer.output_shapes = [[-1, 1, h * w, output_shape[-1]]]
                        # clears original outputs
                        layer.output_indices = []
                        layer.outputs = []

                        spatial_flatten_layers = [layer]
                        # adds more #x-1 spatial flatten layers
                        for i in range(num_of_splits - 1):
                            spatial_flatten = FormatConversionLayer()
                            spatial_flatten.name = f"format_conversion_{base_index}"
                            spatial_flatten.conversion_type = FormatConversionType.spatial_reshape
                            spatial_flatten.spatial_reshape_sizes = layer.spatial_reshape_sizes.copy()
                            spatial_flatten.index = base_index
                            spatial_flatten.original_names = layer.original_names.copy()

                            # all splitter output shapes are identical, so indexing with i is sufficient
                            spatial_flatten.inputs = [feature_splitter.name]
                            spatial_flatten.input_shapes = [feature_splitter.output_shapes[i].copy()]
                            spatial_flatten.input_indices = [feature_splitter.index]

                            spatial_flatten.output_shapes = layer.output_shapes.copy()

                            feature_splitter.append_output_layer(spatial_flatten.name)
                            feature_splitter.append_output_index(spatial_flatten.index)

                            new_layers.append(spatial_flatten)

                            self._graph.add_node(spatial_flatten)
                            self._graph.add_edge(feature_splitter, spatial_flatten)

                            spatial_flatten_layers.append(spatial_flatten)
                            base_index += 1

                        # creates the concat layer
                        concat = ConcatLayer()
                        concat.name = f"concat_{base_index}"
                        concat.original_names = layer.original_names.copy()
                        concat.index = base_index
                        concat.axis = ConcatAxis.spatial_w

                        self._graph.add_node(concat)
                        new_layers.append(concat)
                        base_index += 1

                        for i, spatial_flatten_layer in enumerate(spatial_flatten_layers):
                            # creates the transpose layers
                            transpose = FormatConversionLayer()
                            transpose.conversion_type = FormatConversionType.transpose_height_width
                            # NOTE(review): base_index was already advanced past the concat index, so
                            # the "+ 1" leaves one index unused - confirm the gap is intentional
                            transpose.index = base_index + 1
                            transpose.name = f"transpose_{transpose.index}"
                            transpose.original_names = layer.original_names.copy()

                            # assigns the spatial flatten as the transpose input
                            transpose.inputs = [spatial_flatten_layer.name]
                            transpose.input_shapes = [spatial_flatten_layer.output_shapes[0]]
                            transpose.input_indices = [spatial_flatten_layer.index]

                            # assigns the transpose as the spatial flatten output
                            # (the output shape is the spatial flatten output shape with h and w swapped)
                            transpose.output_shapes = [spatial_flatten_layer.output_shapes[0][i] for i in [0, 2, 1, 3]]
                            spatial_flatten_layer.append_output_layer(transpose.name)
                            spatial_flatten_layer.append_output_index(transpose.index)

                            # assigns the concat as the transpose output
                            transpose.append_output_layer(concat.name)
                            transpose.append_output_index(concat.index)

                            # assigns the transpose as the concat input
                            concat.append_input_index(transpose.index)
                            concat.append_input_layer(transpose.name)
                            concat.append_input_shape(transpose.output_shapes[0])

                            # connects the graph
                            self._graph.add_node(transpose)
                            self._graph.add_edge(spatial_flatten_layer, transpose)
                            HailoNN.update_successors_meta_data(concat, successors_meta_data)
                            self._graph.add_edge(transpose, concat)
                            new_layers.append(transpose)

                            base_index += 1

                        # reconnects the successors of the original reshape layer to the concat
                        for succ_name in original_outputs:
                            succ = self._graph.get_layer_by_name(succ_name)
                            concat.append_output_index(succ.index)
                            concat.append_output_layer(succ.name)
                            concat.append_output_shape([-1, out_h, out_w, out_c])

                            # changes the input of the successors of the format conversion layer
                            succ.replace_input_index(layer.index, concat.index)
                            succ.replace_input_layer(layer.name, concat.name)
                            succ.replace_input_shape(layer.name, concat.output_shapes[0])
                            HailoNN.update_successors_meta_data(succ, successors_meta_data)
                            self._graph.remove_edge(layer, succ)
                            self._graph.add_edge(concat, succ)
                    else:
                        raise BackendFuserException(f"Unexpected general reshape form in layer {layer.name}")

        # the new layers are relaxed into the graph only after all reshapes were expanded
        for layer in new_layers:
            self._graph.relax_new_layer_into_graph(layer, successors_meta_data)

    def _handle_split_windowed_attention(self):
        """
        This function replaces format conversion layers that split the input to windows
        (like done in windowed attention) with an equivalent block.

        the format conversion layer will be replaced with:
        input -> spatial_reshape to [h, w] (only if the predecessor output is not already [h, w])
        -> spatial_reshape to [1, window_size] with input/output windows set

        shapes of the new layers:
        predecessor output -> [-1, h, w, c] -> [-1, number_of_windows, window_size, c]
        (number_of_windows and window_size are swapped when attention_params["width_features"] is falsy)

        for example:
        window shape: 7x7 -> window size: 49
        input shape = [b, h * w, c] -> [1, 3136, 96]
        number of windows: (h * w) / window_size = (56*56) / 49 = 64
        output shape = [1, number_of_windows, window_size, c] -> [1, 64, 49, 96]

        Layers with more than one predecessor are skipped.
        """
        for layer in list(self._graph):
            if (
                layer.op == LayerType.format_conversion
                and layer.conversion_type == FormatConversionType.split_windowed_attention
            ):
                h = int(layer.attention_params["h"])
                w = int(layer.attention_params["w"])
                f_out = layer.output_shapes[0][-1]

                # manipulating width-features or height-features depending on the model structure
                width_features = layer.attention_params["width_features"]
                if width_features:
                    num_windows = layer.attention_params["num_windows"]
                    window_size = layer.attention_params["window_size"]
                else:
                    num_windows = layer.attention_params["window_size"]
                    window_size = layer.attention_params["num_windows"]

                # first layer of the replacing block; set to spatial_unflatten only if it is needed
                first_layer = None

                base_index = self._graph.get_next_index()
                preds = list(self.graph.predecessors(layer))
                if len(preds) > 1:
                    continue

                # the predecessor output is not in [h, w] spatial form - unflatten it first
                if preds[0].output_shapes[0][1:3] != [h, w]:
                    spatial_unflatten = FormatConversionLayer()
                    spatial_unflatten.name = f"{layer.name}_spatial_reshape{base_index}"
                    spatial_unflatten.conversion_type = FormatConversionType.spatial_reshape
                    spatial_unflatten.spatial_reshape_sizes = [h, w]
                    spatial_unflatten.index = base_index
                    spatial_unflatten.original_names = layer.original_names.copy()
                    spatial_unflatten.input_shapes = layer.input_shapes.copy()
                    spatial_unflatten.output_shapes = [[-1, h, w, f_out]]
                    first_layer = spatial_unflatten
                    base_index += 1

                # assumes num_windows is a perfect square (int() truncates the square root) - TODO confirm
                windowed_reshape = FormatConversionLayer()
                windowed_reshape.name = f"{layer.name}_format_conversion_{base_index}"
                windowed_reshape.conversion_type = FormatConversionType.spatial_reshape
                windowed_reshape.spatial_reshape_sizes = [1, window_size]
                windowed_reshape.input_windows = [int(num_windows**0.5), int(num_windows**0.5), 1]
                windowed_reshape.output_windows = [num_windows, 1, 1]
                windowed_reshape.index = base_index
                windowed_reshape.original_names = layer.original_names.copy()
                windowed_reshape.input_shapes = [[-1, h, w, f_out]]
                windowed_reshape.output_shapes = [-1, num_windows, window_size, f_out]
                first_layer = first_layer if first_layer else windowed_reshape
                base_index += 1

                # connects the new layers to the graph
                self._fuser_helper.replace_succ(preds[0], layer, first_layer)
                self._fuser_helper.add_preds(first_layer, preds, update_input_shapes=False)

                # chains spatial_unflatten -> windowed_reshape when the unflatten was created
                if first_layer != windowed_reshape:
                    self._fuser_helper.add_succs(spatial_unflatten, [windowed_reshape], update_output_shapes=False)
                    self._fuser_helper.add_preds(windowed_reshape, [spatial_unflatten], update_input_shapes=False)

                # the successors of the original layer now consume the windowed reshape
                succs = list(self._graph.successors(layer))
                for succ in succs:
                    self._fuser_helper.replace_pred(succ, layer, windowed_reshape)
                self._fuser_helper.add_succs(windowed_reshape, succs, update_output_shapes=False)

                self.graph.remove_layer(layer)

    def _handle_merge_windowed_attention(self):
        """
        This function replaces format conversion layers that merge windows back to the input
        (like done in windowed attention) with an equivalent block.

        the format conversion layer will be replaced with:
        input -> spatial_reshape with input/output windows set
        -> spatial_reshape that flattens the result (only when attention_params["flatten_end"] is truthy)

        shapes of the new layers:
        layer input -> [-1, sqrt(number_of_windows * window_size), sqrt(number_of_windows * window_size), c]
        (-> [-1, 1, h * w, c] when flattened)

        for example:
        window shape: 7x7 -> window size: 49
        number of windows: (h * w) / window_size = (56*56) / 49 = 64
        input shape = [b, number_of_windows, window_size, c] -> [1, 64, 49, 96]
        output shape = [b, sqrt(number_of_windows*window_size), sqrt(number_of_windows*window_size), c] -> [1, 56, 56, 96] (or [1, 1, 3136, 96])
        """
        for layer in list(self._graph):
            if (
                layer.op == LayerType.format_conversion
                and layer.conversion_type == FormatConversionType.merge_windowed_attention
            ):
                width_features = layer.attention_params["width_features"]
                window_h = layer.attention_params["window_h"]
                f_out = layer.output_shapes[0][-1]
                base_index = self._graph.get_next_index()
                num_windows = layer.attention_params["num_windows"]
                window_size = layer.attention_params["window_size"]

                # assumes num_windows and window_size * num_windows are perfect squares
                # (int() truncates the square roots) - TODO confirm
                windowed_reshape = FormatConversionLayer()
                windowed_reshape.name = f"{layer.name}_format_conversion_{base_index}"
                windowed_reshape.conversion_type = FormatConversionType.spatial_reshape
                windowed_reshape.spatial_reshape_sizes = [window_h, window_h]
                windowed_reshape.input_windows = [num_windows, 1, 1] if width_features else [1, num_windows, 1]
                windowed_reshape.output_windows = [int(num_windows**0.5), int(num_windows**0.5), 1]
                windowed_reshape.index = base_index
                windowed_reshape.original_names = layer.original_names.copy()
                windowed_reshape.input_shapes = layer.input_shapes.copy()
                windowed_reshape.output_shapes = [
                    -1,
                    int((window_size * num_windows) ** 0.5),
                    int((window_size * num_windows) ** 0.5),
                    f_out,
                ]
                # last layer of the replacing block; replaced by spatial_flatten when it is created
                last_layer = windowed_reshape
                base_index += 1

                if layer.attention_params.get("flatten_end", False):
                    spatial_flatten = FormatConversionLayer()
                    spatial_flatten.name = f"{layer.name}_spatial_reshape{base_index}"
                    spatial_flatten.conversion_type = FormatConversionType.spatial_reshape
                    out_shape = [1, int(layer.attention_params["h"] * layer.attention_params["w"]), f_out]
                    # NOTE(review): spatial_reshape_sizes gets three elements here (features included),
                    # while the other spatial reshapes in this file use two - confirm this is intended
                    spatial_flatten.spatial_reshape_sizes = out_shape
                    spatial_flatten.index = base_index
                    spatial_flatten.original_names = layer.original_names.copy()
                    spatial_flatten.input_shapes = windowed_reshape.output_shapes.copy()
                    spatial_flatten.output_shapes = [[-1, *out_shape]]
                    last_layer = spatial_flatten
                    base_index += 1

                # connects the new layers to the graph
                preds = list(self.graph.predecessors(layer))
                self._fuser_helper.replace_succ(preds[0], layer, windowed_reshape)
                self._fuser_helper.add_preds(windowed_reshape, preds, update_input_shapes=False)

                # chains windowed_reshape -> spatial_flatten when the flatten was created
                if last_layer != windowed_reshape:
                    self._fuser_helper.add_succs(windowed_reshape, [spatial_flatten], update_output_shapes=False)
                    self._fuser_helper.add_preds(spatial_flatten, [windowed_reshape], update_input_shapes=False)

                # the successors of the original layer now consume the last layer of the block
                succs = list(self._graph.successors(layer))
                for succ in succs:
                    self._fuser_helper.replace_pred(succ, layer, last_layer)
                self._fuser_helper.add_succs(last_layer, succs, update_output_shapes=False)

                self.graph.remove_layer(layer)

    def _handle_groups_to_spatial_flatten(self):
        """
        This function replaces format conversion of groups to spatial flatten with equivalent block.
        input shape: [n, h, w, g, c]
        output shape: [n, 1, h*w*g, c] / [n, 1, h*w*g1, c*g2]

        the format conversion layer will be replaced with:
        spatial_reshape -> depth_to_space (dcr, block sizes [groups, 1]) -> spatial_reshape

        shapes of the new layers (c is the input features divided by layer.groups):
        [-1, 1, h*w, c*groups] -> [-1, groups, h*w, c] -> [-1, 1, h*w*groups, c]

        NOTE(review): groups_to_spatial_flatten and partial_groups_to_spatial_flatten are handled
        identically here, both by layer.groups - confirm this also covers the partial case.
        """
        for layer in list(self._graph):
            if layer.op == LayerType.format_conversion and layer.conversion_type in [
                FormatConversionType.groups_to_spatial_flatten,
                FormatConversionType.partial_groups_to_spatial_flatten,
            ]:
                groups = layer.groups
                h, w, f_out = layer.input_shape[1:]
                f_out //= groups  # the output features are divided by the number of groups
                base_index = self._graph.get_next_index()

                # flattens the spatial dims: [-1, 1, h*w, f_out*groups]
                spatial_reshape = FormatConversionLayer()
                spatial_reshape.name = f"{layer.name}_spatial_reshape{base_index}"
                spatial_reshape.conversion_type = FormatConversionType.spatial_reshape
                spatial_reshape.spatial_reshape_sizes = [1, h * w]
                spatial_reshape.index = base_index
                spatial_reshape.original_names = layer.original_names.copy()
                spatial_reshape.input_shapes = layer.input_shapes.copy()
                spatial_reshape.output_shapes = [[-1, *spatial_reshape.spatial_reshape_sizes, f_out * groups]]
                base_index += 1

                # moves the groups from the features to the height: [-1, groups, h*w, f_out]
                depth2space = DepthToSpaceLayer()
                depth2space.name = f"{layer.name}_depth_to_space{base_index}"
                depth2space.depth_to_space_type = DepthToSpaceType.dcr
                depth2space.block_sizes = [groups, 1]
                depth2space.index = base_index
                depth2space.original_names = layer.original_names.copy()
                depth2space.input_shapes = spatial_reshape.output_shapes.copy()
                depth2space.output_shapes = [[-1, groups, h * w, f_out]]
                base_index += 1

                # flattens the groups into the width: [-1, 1, h*w*groups, f_out]
                spatial_reshape2 = FormatConversionLayer()
                spatial_reshape2.name = f"{layer.name}_spatial_reshape{base_index}"
                spatial_reshape2.conversion_type = FormatConversionType.spatial_reshape
                spatial_reshape2.spatial_reshape_sizes = [1, h * w * groups]
                spatial_reshape2.index = base_index
                spatial_reshape2.original_names = layer.original_names.copy()
                spatial_reshape2.input_shapes = depth2space.output_shapes.copy()
                spatial_reshape2.output_shapes = [[-1, *spatial_reshape2.spatial_reshape_sizes, f_out]]
                base_index += 1

                # connects the new layers to the graph
                succs = list(self._graph.successors(layer))
                self._fuser_helper.add_succs(spatial_reshape, [depth2space], update_output_shapes=False)
                self._fuser_helper.add_succs(depth2space, [spatial_reshape2], update_output_shapes=False)
                self._fuser_helper.add_succs(spatial_reshape2, succs, update_output_shapes=False)

                preds = list(self._graph.predecessors(layer))
                self._fuser_helper.add_preds(spatial_reshape, preds, update_input_shapes=False)
                self._fuser_helper.add_preds(depth2space, [spatial_reshape], update_input_shapes=False)
                self._fuser_helper.add_preds(spatial_reshape2, [depth2space], update_input_shapes=False)

                # redirects the neighbours of the original layer to the edges of the new block
                for pred in list(self._graph.predecessors(layer)):
                    self._fuser_helper.replace_succ(pred, layer, spatial_reshape)

                for succ in succs:
                    self._fuser_helper.replace_pred(succ, layer, spatial_reshape2)
                self.graph.remove_layer(layer)

    def _handle_spatial_flatten_to_groups(self):
        """
        This function replaces format conversion of spatial flatten to groups with equivalent block.
        input shape: [n, 1, h*w*g, c]
        output shape: [n, h, w, g, c]

        the format conversion layer will be replaced with:
        spatial_reshape -> space_to_depth (block sizes [groups, 1]) -> spatial_reshape

        shapes of the new layers:
        [-1, groups, h*w, c] -> [-1, 1, h*w, c*groups] -> [-1, w, h, c*groups]

        The output height and width are taken as equal (h == w), both computed as the truncated
        square root of the flattened width divided by layer.groups.
        """
        for layer in list(self._graph):
            if (
                layer.op == LayerType.format_conversion
                and layer.conversion_type == FormatConversionType.spatial_flatten_to_groups
            ):
                groups = layer.groups
                # assumes a square spatial output (int() truncates the square root) - TODO confirm
                h_out = w_out = int((layer.input_shape[2] // groups) ** 0.5)
                f_out = layer.input_shape[-1] * groups
                base_index = self._graph.get_next_index()

                # exposes the groups as the height: [-1, groups, h_out*w_out, c]
                spatial_reshape = FormatConversionLayer()
                spatial_reshape.name = f"{layer.name}_spatial_reshape{base_index}"
                spatial_reshape.conversion_type = FormatConversionType.spatial_reshape
                spatial_reshape.spatial_reshape_sizes = [groups, h_out * w_out]
                spatial_reshape.index = base_index
                spatial_reshape.original_names = layer.original_names.copy()
                spatial_reshape.input_shapes = layer.input_shapes.copy()
                spatial_reshape.output_shapes = [[-1, *spatial_reshape.spatial_reshape_sizes, layer.input_shape[-1]]]
                base_index += 1

                # moves the groups from the height to the features: [-1, 1, h_out*w_out, f_out]
                # NOTE(review): this is a SpaceToDepthLayer, but it is named "..._depth_to_space..." and
                # gets a depth_to_space_type attribute - looks copied from the groups_to_spatial_flatten
                # handler; confirm whether space_to_depth_type (SpaceToDepthType) was intended
                space2depth = SpaceToDepthLayer()
                space2depth.name = f"{layer.name}_depth_to_space{base_index}"
                space2depth.depth_to_space_type = DepthToSpaceType.dcr
                space2depth.block_sizes = [groups, 1]
                space2depth.index = base_index
                space2depth.original_names = layer.original_names.copy()
                space2depth.input_shapes = spatial_reshape.output_shapes.copy()
                space2depth.output_shapes = [[-1, 1, h_out * w_out, f_out]]
                base_index += 1

                # unflattens the spatial dims; the sizes are listed as [w_out, h_out], the two are equal here
                spatial_reshape2 = FormatConversionLayer()
                spatial_reshape2.name = f"{layer.name}_spatial_reshape{base_index}"
                spatial_reshape2.conversion_type = FormatConversionType.spatial_reshape
                spatial_reshape2.spatial_reshape_sizes = [w_out, h_out]
                spatial_reshape2.index = base_index
                spatial_reshape2.original_names = layer.original_names.copy()
                spatial_reshape2.input_shapes = space2depth.output_shapes.copy()
                spatial_reshape2.output_shapes = [[-1, *spatial_reshape2.spatial_reshape_sizes, f_out]]
                base_index += 1

                # connects the new layers to the graph
                succs = list(self._graph.successors(layer))
                self._fuser_helper.add_succs(spatial_reshape, [space2depth], update_output_shapes=False)
                self._fuser_helper.add_succs(space2depth, [spatial_reshape2], update_output_shapes=False)
                self._fuser_helper.add_succs(spatial_reshape2, succs, update_output_shapes=False)

                preds = list(self._graph.predecessors(layer))
                self._fuser_helper.add_preds(spatial_reshape, preds, update_input_shapes=False)
                self._fuser_helper.add_preds(space2depth, [spatial_reshape], update_input_shapes=False)
                self._fuser_helper.add_preds(spatial_reshape2, [space2depth], update_input_shapes=False)

                # redirects the neighbours of the original layer to the edges of the new block
                for pred in list(self._graph.predecessors(layer)):
                    self._fuser_helper.replace_succ(pred, layer, spatial_reshape)

                for succ in succs:
                    self._fuser_helper.replace_pred(succ, layer, spatial_reshape2)
                self.graph.remove_layer(layer)

    def _handle_downsample_by_two_with_slice(self):
        """
        Swaps base slice layers that subsample the input by two (h, w) for depthwise conv layers whose kernel
        picks a single element out of every 2x2 window.
        The kernel shape is [2, 2, f_in, 1] and the strides are [1, height step, width step, 1].
        Raises BackendFuserException if a base slice with a step other than 1 is still in the graph afterwards.
        """
        graph = self._graph
        replaced_slices = []

        for slice_layer in list(graph):
            if slice_layer.op != LayerType.base_slice:
                continue

            h_slice = slice_layer.height_slice
            w_slice = slice_layer.width_slice
            subsamples_by_two = (
                2 in (h_slice[-1], w_slice[-1])  # a spatial step of 2
                and h_slice[0] in (0, 1)
                and w_slice[0] in (0, 1)
                # both slice ends are compared against input_shape[-2], i.e. the whole tensor is covered
                and h_slice[-2] == w_slice[-2] == slice_layer.input_shape[-2]
                and slice_layer.features_slice[-1] == 1
            )
            if not subsamples_by_two:
                continue

            replacement = Conv2DLayer()
            replacement.name = f"{slice_layer.name}_dw_conv"
            replacement.index = graph.get_next_index()
            replacement.original_names = slice_layer.original_names.copy()
            replacement.input_shapes = slice_layer.input_shapes.copy()
            replacement.output_shapes = slice_layer.output_shapes.copy()
            replacement.op = LayerType.base_dw
            # the 2x2 window needs padding on the right/bottom edge, hence same_tensorflow
            replacement.padding = PaddingType.same_tensorflow
            in_features = slice_layer.input_shapes[0][-1]
            replacement.bias = np.zeros(in_features, dtype=np.float32)
            replacement.kernel_shape = [2, 2, in_features, 1]
            # a single 1 placed at the slice start offsets imitates the strided slice; the same 2x2 selector is
            # replicated for every input feature
            selector = np.zeros((2, 2, 1, 1))
            selector[h_slice[0], w_slice[0]] = 1
            replacement.kernel = np.repeat(selector, in_features, axis=2)
            replacement.strides = [1, h_slice[-1], w_slice[-1], 1]
            replacement.dilations = [1, 1, 1, 1]

            # wire the conv in place of the slice: outputs first, then inputs
            consumers = list(graph.successors(slice_layer))
            self._fuser_helper.add_succs(replacement, consumers)
            for consumer in consumers:
                self._fuser_helper.replace_pred(consumer, slice_layer, replacement)

            producers = list(graph.predecessors(slice_layer))
            self._fuser_helper.add_preds(replacement, producers)
            for producer in producers:
                self._fuser_helper.replace_succ(producer, slice_layer, replacement)

            replaced_slices.append(slice_layer)

        for replaced_slice in replaced_slices:
            graph.remove_layer(replaced_slice)

        # any base slice that is still strided at this point cannot be handled
        for remaining in graph:
            if remaining.op != LayerType.base_slice:
                continue

            slices_args = [remaining.height_slice, remaining.width_slice, remaining.features_slice]
            if all(slice_args[2] == 1 for slice_args in slices_args):
                continue

            raise BackendFuserException(
                f"{remaining.full_name_msg} is unsupported due to strided slicing: "
                f"height_slice={remaining.height_slice}, "
                f"width_slice={remaining.width_slice}, features_slice={remaining.features_slice}",
            )

    def _add_conv_before_dynamic_weights_layer(self):
        """
        This method adds either:
            1. conv before matmul with transpose input and other layers with dynamic weights to enable zp correction.
            2. standalone activation for the second attention matmul (without transpose input) to allow uint to int
               conversion.
        The method will be removed in SDK-47146.

        Only layers with dynamic weights whose op is not dw are handled. The addition is skipped when one of the
        layer's predecessors already has an allowed op (directly, or through a single format conversion / base
        slice) and an allowed number of outputs, or when the layer has no predecessors at all. Otherwise the new
        layer is placed between the layer and its last predecessor, and is given the layer's second input shape.
        """
        new_layers = []
        successors_meta_data = {}

        allowed_pred_ops = [
            LayerType.conv,
            LayerType.dw,
            LayerType.slice,
            LayerType.feature_splitter,
            LayerType.normalization,
            LayerType.ew_sub,
        ]
        non_modifying_range_ops = [LayerType.format_conversion, LayerType.base_slice]

        # snapshot of the candidates, the graph is modified inside the loop
        dynamic_weights_layers = [
            layer for layer in self.output_graph if layer.dynamic_weights and layer.op != LayerType.dw
        ]
        for layer in dynamic_weights_layers:
            preds = list(self.output_graph.predecessors(layer))
            if not preds:
                # no incoming edge to place the new layer on; also prevents reusing a predecessor that belongs to
                # a previously handled layer
                continue

            redundant_addition = False
            for pred in preds:
                cur_pred_preds = list(self.output_graph.predecessors(pred))
                allowed_pred_op = pred.op in allowed_pred_ops or (
                    len(cur_pred_preds) == 1
                    and pred.op in non_modifying_range_ops
                    and cur_pred_preds[0].op in allowed_pred_ops
                )
                allowed_number_of_outputs = (
                    pred.op == LayerType.feature_splitter or len(list(self.output_graph.successors(pred))) == 1
                )
                if allowed_pred_op and allowed_number_of_outputs:
                    redundant_addition = True
                    break

            # dw layers were already filtered out above, so no further op check is needed here
            if redundant_addition:
                continue

            # none of the predecessors qualified, the new layer goes after the last one
            weights_pred = preds[-1]

            is_matmul_without_transpose_input = layer.op == LayerType.matmul and not layer.transpose_matmul_input
            layer_class = FusedStandaloneActivationLayer if is_matmul_without_transpose_input else FusedConv2DLayer

            output_shape = layer.input_shapes[1]
            weights_layer = self._fuser_helper.create_layer(
                layer_class,
                self.output_graph.get_next_index(),
                "weights_layer",
                layer,
                new_layers,
                [output_shape],
            )
            if weights_layer.op == LayerType.conv:
                # turn the conv into an identity depthwise 1x1 conv (kernel of ones, zero bias)
                weights_layer.op = LayerType.dw
                weights_layer.padding = PaddingType.valid
                input_features = output_shape[-1]
                weights_layer.bias = np.zeros(input_features, dtype=np.float32)
                weights_layer.kernel_shape = [1, 1, input_features, 1]
                weights_layer.kernel = np.reshape(
                    np.ones(input_features, dtype=np.float32),
                    weights_layer.kernel_shape,
                )
                weights_layer.strides = [1, 1, 1, 1]
                weights_layer.dilations = [1, 1, 1, 1]

            self._fuser_helper.add_preds(weights_layer, [weights_pred])
            self._fuser_helper.add_succs(weights_layer, [layer])
            self._fuser_helper.replace_pred(layer, weights_pred, weights_layer)
            self._fuser_helper.replace_succ(weights_pred, layer, weights_layer)

            self._logger.debug(f"Inserted conv layer to enable dynamic weights in layer {layer.name}")

        for layer in new_layers:
            self.output_graph.relax_new_layer_into_graph(layer, successors_meta_data)

    def _add_shortcut_to_empty_model(self):
        """
        Calls the fuser helper's replace_layer_with_shortcut (with in_out_shortcut=True) on every input layer of
        the output graph, then relaxes the layers collected in the process into the graph.
        """
        created_layers = []
        relax_meta_data = {}

        # walk a snapshot, the helper call works on the graph that is being traversed
        for candidate in tuple(self.output_graph):
            if candidate.op == LayerType.input_layer:
                self._fuser_helper.replace_layer_with_shortcut(candidate, created_layers, in_out_shortcut=True)

                self._logger.debug("Inserted shortcut layer between input and output layers")

        for created_layer in created_layers:
            self.output_graph.relax_new_layer_into_graph(created_layer, relax_meta_data)

    def _handle_dense_before_flat_to_frame(self):
        """
        This function replaces dense followed by flat to frames with conv1x1.

        For every flat_to_frames format conversion that has exactly one dense predecessor, the dense layer and the
        chain of dense layers found behind it by _find_dense_predecessors_path are converted with
        FuserHelper.replace_dense_with_conv1x1, and the format conversion layer is removed from the graph.
        """
        layers_to_remove = []

        # iterate over a snapshot, like the other graph passes do: the helpers called in the loop receive the graph
        # and the visited layers, and may modify the graph while it is being traversed
        for layer in list(self._output_graph):
            is_flat_to_frames = (
                layer.op == LayerType.format_conversion
                and layer.conversion_type == FormatConversionType.flat_to_frames
            )
            if not is_flat_to_frames:
                continue

            dense_preds = [pred for pred in self._output_graph.predecessors(layer) if pred.op == LayerType.dense]
            if len(dense_preds) != 1:
                continue

            for dense in self._find_dense_predecessors_path(dense_preds):
                FuserHelper.replace_dense_with_conv1x1(dense, model=self._output_graph)

            # the actual removal from the graph is deferred until the traversal is over
            self._fuser_helper.remove_layer(layer, layers_to_remove)

        for layer in layers_to_remove:
            self._output_graph.remove_layer(layer)

    def _find_dense_predecessors_path(self, layers):
        for layer in layers:
            if layer.op != LayerType.dense:
                return []
            return [layer, *self._find_dense_predecessors_path(list(self._output_graph.predecessors(layer)))]
        return []

    def _handle_neg_feature_shuffle(self):
        """
        This function detects a feature shuffle with neg norm (-1) block (QWEN block variations),
        injects the -1 into the ew_mult const input, and replaces the feature_splitter & concat with feature_shuffle

        The detected pattern is:
            1. a feature_splitter with a single predecessor and exactly two successors: a normalization and a concat.
            2. the concat works on the features axis, has one group size per splitter group, and its only
               successor is an ew_mult.
            3. the ew_mult has two predecessors, one of which is a const_input.
            4. the normalization kernel is all -1 and its bias is all 0.
        """
        # handed to create_layer below; it is not used otherwise in this part of the method
        new_layers = []
        layers_to_remove = []
        # iterate over a snapshot, the graph connections are edited inside the loop
        for layer in list(self._graph):
            if layer.op != LayerType.feature_splitter:
                continue

            feature_splitter_preds = list(self._graph.predecessors(layer))
            succs = list(self._graph.successors(layer))
            succs_ops = [succ.op for succ in succs]
            if (
                len(feature_splitter_preds) != 1
                or len(succs) != 2
                or not (LayerType.normalization in succs_ops and LayerType.concat in succs_ops)
            ):
                continue

            # the two successors may appear in any order
            neg_norm = succs[0] if succs[0].op == LayerType.normalization else succs[1]
            concat = succs[0] if succs[0].op == LayerType.concat else succs[1]
            concat_succs = list(self._graph.successors(concat))
            if len(concat_succs) != 1 or concat_succs[0].op != LayerType.ew_mult:
                continue

            ew_mult = concat_succs[0]
            ew_mult_preds = list(self._graph.predecessors(ew_mult))
            is_const_input_ew_mult = len(ew_mult_preds) == 2 and any(
                pred.op == LayerType.const_input for pred in ew_mult_preds
            )
            # the normalization only negates its input: multiplies by -1 and adds nothing
            is_neg_normalization = np.all(neg_norm.kernel == -1) and np.all(neg_norm.bias == 0)

            if (
                len(concat.group_sizes) != layer.groups
                or concat.axis != ConcatAxis.features
                or not is_const_input_ew_mult
                or not is_neg_normalization
            ):
                continue

            # inject the neg value into the ew_mult const input data
            ew_mult_const_input = ew_mult_preds[0] if ew_mult_preds[0].op == LayerType.const_input else ew_mult_preds[1]
            input_values_shape = ew_mult_const_input.const_values.shape
            # create a vector where the first part of the features (features / number of splitter outputs) is -1
            # and the rest is 1
            half_elements = input_values_shape[-1] // len(layer.output_shapes)
            neg_vector = np.ones(input_values_shape)
            neg_vector[..., :half_elements] = -1
            # multiply the shuffled 1,-1 tensor by the original input_values
            ew_mult_const_input.const_values = ew_mult_const_input.const_values * neg_vector
            # add the negative constant input to the constant input original names, remove from feature shuffle
            ew_mult_const_input.original_names.append(neg_norm.name)

            # generate feature shuffle layer
            feature_shuffle = self._fuser_helper.create_layer(
                FeatureShuffleLayer,
                self._graph.get_next_index(),
                "feature_shuffle",
                layer,
                new_layers,
                [layer.input_shape],
            )
            feature_shuffle.groups = layer.groups
            # the slice inside each group starts after the per-group features of the first splitter output and
            # spans the per-group features of the second splitter output
            start = layer.output_shapes[0][-1] // layer.groups
            end = start + layer.output_shapes[1][-1] // layer.groups
            feature_shuffle.groups_slice = [start, end, 1]
            feature_shuffle.output_shapes = concat.output_shapes.copy()
            feature_shuffle.input_shapes = layer.input_shapes.copy()
            feature_shuffle.original_names.extend(concat.original_names)
            # replace the feature shuffle building block with a single layer
            self._fuser_helper.replace_pred(ew_mult, concat, feature_shuffle)
            self._fuser_helper.replace_succ(feature_splitter_preds[0], layer, feature_shuffle)
            self._fuser_helper.add_preds(feature_shuffle, [feature_splitter_preds[0]])
            self._fuser_helper.add_succs(feature_shuffle, [ew_mult])
            layers_to_remove.extend([neg_norm, concat, layer])

        # the replaced layers are removed only after the traversal is over
        for layer_to_remove in layers_to_remove:
            self._graph.remove_layer(layer_to_remove)
