#!/usr/bin/env python

import itertools
import re

import networkx as nx
import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_BOX_AND_OBJ_PXLS,
    DEFAULT_CONCAT_AXIS,
    BBoxDecodersInfo,
    ConcatAxis,
    NMSProperties,
)
from hailo_sdk_client.exposed_definitions import Dims, NNFramework
from hailo_sdk_client.model_translator.edge_nn_translator import EdgeNNConverter, VertexState
from hailo_sdk_client.model_translator.exceptions import (
    RecordableParserError,
    UnexpectedNodeError,
    UnsupportedActivationLayerError,
    UnsupportedConstInputError,
    UnsupportedEqualLayerError,
    UnsupportedFormatConversionLayerError,
    UnsupportedLogitsLayerError,
    UnsupportedModelError,
    UnsupportedOperationError,
    UnsupportedReduceL2LayerError,
    UnsupportedReduceMaxLayerError,
    UnsupportedReduceMinLayerError,
    UnsupportedReduceSumLayerError,
    UnsupportedReduceSumSquareLayerError,
    UnsupportedResizeLayerError,
    UnsupportedShuffleLayerError,
    UnsupportedSpaceToDepthLayerError,
    UnsupportedSquareLayerError,
    UnsupportedTileLayerError,
)
from hailo_sdk_client.model_translator.graph_lookup import (
    BwdChainNode,
    FwdChainNode,
    get_all_nodes_from_possible_chains,
)
from hailo_sdk_client.model_translator.onnx_translator.exceptions import NoParamsModelError
from hailo_sdk_client.model_translator.onnx_translator.onnx_graph import (
    ACTIVATION_OPS,
    ADD_OPS,
    ALTERNATIVE_SOFTMAX_OPS,
    BN_INPUT_ORDER,
    BN_OPS,
    CONCAT_OPS,
    CONV2D_INPUT_ORDER,
    CONV2D_OPS,
    DENSE_OPS,
    DIV_OPS,
    EINSUM_OPS,
    EQUAL_OPS,
    EW_OPS,
    GATHER_OPS,
    GRU_OPS,
    INSTANCE_NORMALIZATION_OPS,
    LAYER_NORMALIZATION_OPS,
    LOG_SOFTMAX_OPS,
    LOGITS_OPS,
    LSTM_OPS,
    MAX_OPS,
    MIN_OPS,
    MUL_OPS,
    NEG_OPS,
    NMS_OPS,
    ONE_HOT_OPS,
    PAD_OPS,
    POOL_OPS,
    POW_OPS,
    REDUCE_L2_OPS,
    REDUCE_MAX_OPS,
    REDUCE_MIN_OPS,
    REDUCE_SUM_OPS,
    REDUCE_SUM_SQUARE_OPS,
    RESIZE_OPS,
    RNN_OPS,
    SCATTER_ND_OPS,
    SHUFFLE_OPS,
    SKIP_OPS,
    SLICE_OPS,
    SPLIT_OPS,
    SUB_OPS,
    SUPPORTED_OPS_UNION,
    TILE_OPS,
    ONNXGraph,
)
from hailo_sdk_client.model_translator.onnx_translator.onnx_translator_definitions import Conv2DInfo
from hailo_sdk_client.sdk_backend.script_parser.nms_postprocess_command import (
    YOLO_OUTPUTS_PER_BRANCH,
    YOLOV6_TOTAL_OUTPUTS,
    YOLOX_ACTIVATIONS_PER_REG_LAYER,
    YOLOX_TOTAL_OUTPUTS,
    NMSPostprocessCommand,
    get_f_out_by_meta_arch,
)
from hailo_sdk_client.tools.layers.layers_utils import calculate_padding
from hailo_sdk_common.hailo_nn.exceptions import RecordableCreateLayerError
from hailo_sdk_common.hailo_nn.hn_definitions import (
    ActivationType,
    DepthToSpaceType,
    FeatureMultiplierType,
    FormatConversionType,
    LayerType,
    NMSMetaArchitectures,
    NormalizationType,
    PaddingType,
    ResizeBilinearPixelsMode,
    ResizeMethod,
    SpaceToDepthType,
    TemporaryPaddingType,
)
from hailo_sdk_common.hailo_nn.hn_layers import (
    ActivationLayer,
    ArgmaxLayer,
    BatchNormLayer,
    BiasAddLayer,
    ConcatLayer,
    ConstInputLayer,
    Conv2DLayer,
    DenseLayer,
    DepthToSpaceLayer,
    EqualLayer,
    EWAddLayer,
    EWDivLayer,
    EWMaxLayer,
    EWMeanLayer,
    EWMinLayer,
    EWMultLayer,
    EWSubLayer,
    ExternalPadLayer,
    FeatureMultiplierLayer,
    FeatureShuffleLayer,
    FeatureSplitterLayer,
    FormatConversionLayer,
    GRULayer,
    L2NormalizationLayer,
    LayerNormalizationLayer,
    LogSoftmaxLayer,
    LSTMLayer,
    MatmulLayer,
    NormalizationLayer,
    NullLayer,
    OneHotLayer,
    PoolingLayer,
    ReduceL2Layer,
    ReduceMaxLayer,
    ReduceMeanLayer,
    ReduceMinLayer,
    ReduceSumLayer,
    ReduceSumSquareLayer,
    ResizeLayer,
    RNNLayer,
    ScatterNDLayer,
    ShortcutLayer,
    SliceLayer,
    SoftmaxLayer,
    SpaceToDepthLayer,
    SpatialSplitterLayer,
    TransposeLayer,
    WidthSplitterLayer,
)
from hailo_sdk_common.hailo_nn.hn_layers.einsum import EinsumLayer
from hailo_sdk_common.hailo_nn.nms_postprocess_defaults import DEFAULT_YOLO_ANCHORS

# Depth limit for the NMS post-process handling (not used in the visible part of this module; presumably bounds the
# backward search for layers shared by NMS branches — confirm at the usage site).
NMS_COMMON_LAYERS_DEPTH = 20  # Maximum depth to traverse backwards from the output node


class ONNXConverter(EdgeNNConverter):
    def __init__(self, model, values, output_shapes, start_node_names=None, end_node_names=None, net_input_format=None):
        """Create a converter for a loaded ONNX model.

        Args:
            model: The ONNX model. Its ``graph`` provides the nodes, network inputs/outputs, initializers and
                value info; ``opset_import[0].version`` provides the opset version.
            values: Not used — the initializers are always taken from ``model.graph.initializer``.
            output_shapes: Forwarded to :class:`ONNXGraph`.
            start_node_names: Optional start node names, forwarded to the base converter.
            end_node_names: Optional end node names; resolved through ``_get_real_end_node_names`` so that
                graph-output tensor names map to the nodes producing them.
            net_input_format: Forwarded to :class:`ONNXGraph`.
        """
        proto_graph = model.graph
        nodes = proto_graph.node
        graph_outputs = proto_graph.output
        resolved_end_node_names = self._get_real_end_node_names(nodes, graph_outputs, end_node_names)

        parsed_graph = ONNXGraph(
            graph=nodes,
            values=proto_graph.initializer,
            net_input=proto_graph.input,
            tensor_shapes=proto_graph.value_info,
            output_shapes=output_shapes,
            opset_version=model.opset_import[0].version,
            net_input_format=net_input_format,
        )
        super().__init__(
            graph=parsed_graph,
            start_node_names=start_node_names,
            end_node_names=resolved_end_node_names,
        )

        # ONNX-specific state, set after the base class is initialized.
        self._resize_layers_meta_vertices = {}
        self._net_output = graph_outputs
        self._meta_graph = None
        self._nn_framework = NNFramework.ONNX

    def _handle_fused_layers(self):
        """Un-fuse layer information collected during parsing.

        For ONNX models this consists solely of moving pending external paddings into standalone pad layers
        (see ``_separate_external_paddings``).
        """
        self._separate_external_paddings()

    def _calculate_shapes(self, validate_shapes=True):
        """Run shape inference over the layers graph.

        The attention-window counts of windowed-attention matmuls and the resize meta graph are refreshed
        first, because the shape calculation consumes both.

        Args:
            validate_shapes: Forwarded to the layers graph's ``calculate_shapes``.
        """
        self._update_attention_windows()
        self._update_meta_graph()
        layers_graph = self._layers_graph
        layers_graph.calculate_shapes(meta_edges_graph=self._meta_graph, validate_shapes=validate_shapes)

    @staticmethod
    def _attention_windows_chains():
        """Return fresh candidate layer chains from a split-windowed-attention conversion to the (softmax, V) matmul.

        Each chain walks forward to a matmul and ends with a backward step to a softmax, so the matmul of a
        matched chain is its second-to-last node. A new list is built on every call so matching never shares
        chain-node objects between searches.
        """
        return [
            [
                FwdChainNode(op=LayerType.base_conv),
                FwdChainNode(op=LayerType.null),
                FwdChainNode(op=LayerType.null),
                FwdChainNode(op=LayerType.matmul),
                BwdChainNode(op=LayerType.softmax),
            ],
            [
                FwdChainNode(op=LayerType.base_conv),
                FwdChainNode(op=LayerType.feature_splitter),
                FwdChainNode(op=LayerType.normalization),
                FwdChainNode(op=LayerType.null),
                FwdChainNode(op=LayerType.null),
                FwdChainNode(op=LayerType.matmul),
                BwdChainNode(op=LayerType.softmax),
            ],
            [
                FwdChainNode(op=LayerType.base_conv),
                FwdChainNode(op=LayerType.null),
                FwdChainNode(op=LayerType.null),
                FwdChainNode(op=LayerType.feature_splitter),
                FwdChainNode(op=LayerType.null),
                FwdChainNode(op=LayerType.matmul),
                BwdChainNode(op=LayerType.softmax),
            ],
            [
                FwdChainNode(op=LayerType.base_conv),
                FwdChainNode(op=LayerType.null),
                FwdChainNode(op=LayerType.null),
                FwdChainNode(op=LayerType.base_slice),
                FwdChainNode(op=LayerType.matmul),
                BwdChainNode(op=LayerType.softmax),
            ],
            [
                FwdChainNode(op=LayerType.layer_normalization),
                FwdChainNode(op=LayerType.base_conv),
                FwdChainNode(op=LayerType.null),
                FwdChainNode(op=LayerType.feature_splitter),
                FwdChainNode(op=LayerType.null),
                FwdChainNode(op=LayerType.matmul),
                BwdChainNode(op=LayerType.softmax),
            ],
            [
                FwdChainNode(op=LayerType.layer_normalization),
                FwdChainNode(op=LayerType.normalization),
                FwdChainNode(op=LayerType.base_conv),
                FwdChainNode(op=LayerType.null),
                FwdChainNode(op=LayerType.null),
                FwdChainNode(op=LayerType.feature_splitter),
                FwdChainNode(op=LayerType.null),
                FwdChainNode(op=LayerType.matmul),
                BwdChainNode(op=LayerType.softmax),
            ],
        ]

    @staticmethod
    def _set_matmul_windows(matmul, input_windows):
        """Set ``[input_windows, 1, 1]`` as both the input and output window counts of ``matmul``."""
        matmul.input_windows = [input_windows, 1, 1]
        matmul.output_windows = [input_windows, 1, 1]

    def _find_preceding_matmul(self, matmul, attention_layer):
        """Walk first-predecessor links back from ``matmul`` until another matmul layer is found.

        Raises:
            ValueError: if the walk runs out of predecessors before reaching a matmul.
        """
        pred = next(iter(self._layers_graph.predecessors(matmul)), None)
        while pred is not None and pred.op != LayerType.matmul:
            pred = next(iter(self._layers_graph.predecessors(pred)), None)
        if pred is None:
            msg = f"Couldn't find the mul(Q, K^T) matmul of the windowed attention block near {attention_layer.name}"
            raise ValueError(msg)
        return pred

    def _update_attention_windows(self):
        """
        This function propagates the number of windows from the split windowed attention layer,
        the number of windows should be taken into account in calculating of the output shape of the matmul layers.

        For every ``split_windowed_attention`` format-conversion layer, the matmul computing mul((QK^T), V) is
        located through one of the known layer chains, and the closest matmul reached by walking back from it
        (mul(Q, K^T)) is updated as well. Both get ``[windows, 1, 1]`` as input and output windows, where
        ``windows`` is ``num_windows`` when ``width_features`` is set and ``window_size`` otherwise.

        Raises:
            ValueError: if no known chain matches near a windowed-attention layer, or if the mul(Q, K^T)
                matmul cannot be reached by walking back from the mul((QK^T), V) matmul.
        """
        for layer in list(self._layers_graph):
            is_split_windowed_attention = (
                layer.op == LayerType.format_conversion
                and layer.conversion_type == FormatConversionType.split_windowed_attention
            )
            if not is_split_windowed_attention:
                continue

            chain = get_all_nodes_from_possible_chains(
                self._layers_graph, layer, self._attention_windows_chains(), exact_match=True
            )
            if not chain:
                msg = f"Couldn't resolve shape calculations in windowed attention block near {layer.name}"
                raise ValueError(msg)

            params = layer.attention_params
            input_windows = params["num_windows"] if params["width_features"] else params["window_size"]

            # updates the number of windows in each row and column in mul((QK^T), V); the chain ends with the
            # backward softmax step, so the matmul is the second-to-last node
            qk_v_matmul = chain[-2]
            self._set_matmul_windows(qk_v_matmul, input_windows)

            # updates the number of input_windows in each row and column in mul(Q, K^T)
            q_k_matmul = self._find_preceding_matmul(qk_v_matmul, layer)
            self._set_matmul_windows(q_k_matmul, input_windows)

    def _update_meta_graph(self):
        """Build the meta graph used by shape calculation when resize layers depend on other layers.

        Every resize layer recorded in ``self._resize_layers_meta_vertices`` gets its ``resize_layers`` resolved
        from vertices to layers. ``self._meta_graph`` is then rebuilt: it holds all edges of the layers graph plus
        an edge from each dependency to the resize layer that uses it. When nothing was recorded, the method
        returns without touching ``self._meta_graph``.
        """
        pending_resizes = self._resize_layers_meta_vertices
        if not pending_resizes:
            return

        for resize_layer, dependency_vertices in pending_resizes.items():
            resize_layer.resize_layers = [self._vertices_to_layers[v] for v in dependency_vertices]

        # The meta graph only references layers of the original HN: adding the edge collection implicitly adds
        # those shallow-copied layers as nodes, so they are not added one by one.
        self._meta_graph = meta_graph = nx.DiGraph()
        meta_graph.add_edges_from(self._layers_graph.edges)
        for candidate in list(self._layers_graph):
            if candidate.op != LayerType.resize or not candidate.resize_layers:
                continue
            for dependency in candidate.resize_layers:
                meta_graph.add_edge(dependency, candidate)

    def _update_vertices_info(self):
        """Refresh per-vertex information in topological order.

        Every vertex updates its output format. Global-pooling vertices are flagged as spatially 1x1. If any pad
        vertex uses non-constant padding, a single warning is logged after the pass.
        """
        saw_non_const_pad = False
        for vertex in self.graph.nodes_toposorted():
            vertex.update_output_format()
            if vertex.op in POOL_OPS and vertex.is_global_pool():
                vertex.is_spatial_1x1 = True
                continue
            if vertex.op in PAD_OPS and vertex.is_non_const_padding():
                saw_non_const_pad = True

        if not saw_non_const_pad:
            return
        self._logger.warning(
            "This model has non-default (reflective/edge) padding layers which are not supported currently, "
            "and were replaced with zero padding. When the padding precedes pooling layers, we expect slight "
            "degradation in the parsed model. For more info about padding modes, please refer to "
            "https://github.com/onnx/onnx/blob/main/docs/Changelog.md#pad-11",
        )

    @staticmethod
    def _get_real_end_node_names(onnx_proto, net_output, end_node_names):
        """Resolve end node names to the names of ONNX nodes.

        ONNX graph outputs are tensors, while the converter tracks nodes. Any requested end name that is a
        graph-output tensor name is therefore replaced by the name of the node producing that tensor. Other names,
        and graph outputs with no producing node, are kept as given. When no end names are requested, the producers
        of all graph outputs are returned in graph-output order. Outputs without a producing node are skipped.

        The caller's ``end_node_names`` is never modified; a new list is always returned.

        Args:
            onnx_proto: The ONNX nodes (``model.graph.node``).
            net_output: The ONNX graph outputs (``model.graph.output``).
            end_node_names: Optional requested end node / graph-output tensor names.

        Returns:
            list: The resolved end node names.
        """
        # One pass over the nodes instead of a rescan per name. ONNX tensors have a single producer (SSA); if a
        # malformed graph had several, the last one wins, matching a sequential scan.
        producer_by_tensor = {tensor: node.name for node in onnx_proto for tensor in node.output}
        graph_outputs = [output.name for output in net_output]

        if end_node_names:
            graph_output_set = set(graph_outputs)
            return [
                producer_by_tensor.get(name, name) if name in graph_output_set else name for name in end_node_names
            ]
        return [producer_by_tensor[name] for name in graph_outputs if name in producer_by_tensor]

    def _add_output_layers(self, fused_activations=None):
        """Add the output layers and name them after the ONNX graph-output tensors.

        After the base class adds the output layers and names/indices are assigned, each output layer whose
        predecessor carries a requested end node among its original names has its ``original_names`` replaced by
        that end node's graph-output tensor name. If the end node produces several graph outputs, the last one
        listed in ``end_node.output`` wins.

        Args:
            fused_activations: Accepted for signature compatibility; it is neither used here nor forwarded to the
                base implementation.
        """
        super()._add_output_layers()
        self._layers_graph.set_names_and_indices()

        # Loop-invariant: computed once instead of rebuilding a list for every output of every end node.
        graph_output_names = {out.name for out in self._net_output}
        for layer in list(self._layers_graph):
            if layer.op != LayerType.output_layer:
                continue
            pred = next(iter(self._layers_graph.predecessors(layer)))
            end_node_orig_names = [
                orig_name for orig_name in pred.original_names if orig_name in self._end_node_names
            ]
            if not end_node_orig_names:
                continue
            end_node = self.graph.get_vertex_by_name(end_node_orig_names[0])
            for node_output in end_node.output:
                if node_output in graph_output_names:
                    layer.original_names = [node_output]

    def _layer_callback_from_vertex(self, vertex):
        """Create the HN layer(s) for a single ONNX vertex.

        The vertex is matched against the supported op patterns by the ``if``/``elif`` chain below. The order is
        significant: an earlier branch shadows every later one. The matching ``_create_*`` helper is then called.
        The vertices that helper returns as consumed, plus ``vertex`` itself, are passed to
        ``_handle_consumed_vertices``. Helpers whose return value is not captured leave ``consumed_vertices``
        empty, so only ``vertex`` is consumed.

        ``RecordableParserError`` / ``RecordableCreateLayerError`` raised inside the ``try`` are handed to
        ``_handle_recordable_parser_error``, and the vertex is then left unconsumed. ``UnsupportedOperationError``,
        ``UnsupportedShuffleLayerError`` and ``UnexpectedNodeError`` are raised for unsupported or unrecognized
        vertices. Whether they reach the caller depends on whether they subclass the recordable errors, which is
        not visible here.

        Args:
            vertex: The ONNX graph vertex to convert.
        """
        consumed_vertices = []
        should_assign_vertex_to_layer = True
        try:
            # first update if op is affected by batch-first shape transformation (multi-head attention)
            # These pattern probes are evaluated once up front and reused by several branches below.
            is_alternative_softmax = vertex.op in ALTERNATIVE_SOFTMAX_OPS and vertex.is_softmax()
            is_transposed_batch_norm = vertex.is_transposed_batch_norm()
            is_flattened_global_maxpool = vertex.is_flattened_global_maxpool()  # edge case where flatten isn't null op
            is_einsum_conv1x1 = vertex.is_einsum_conv1x1()
            is_grouped_reduce_sum_gather = vertex.is_grouped_reduce_sum_gather()
            is_layer_norm = vertex.is_layer_norm()
            is_regular_gather_slice = vertex.is_gather_slice()
            is_channel_shuffle_gather_slice = vertex.is_channel_shuffle_gather_slice()

            if vertex.op not in SUPPORTED_OPS_UNION:
                msg = f"{vertex.op} operation is unsupported"
                raise UnsupportedOperationError(msg)

            if vertex.is_null_operation() and not is_flattened_global_maxpool:
                consumed_vertices = self._create_null_layer(vertex)
            elif vertex.op in CONV2D_OPS:
                if vertex.is_conv3d():
                    consumed_vertices = self._create_conv3d_layer(vertex)
                elif vertex.is_matmul_layer():
                    # return value not captured: only `vertex` itself is marked consumed
                    self._create_matmul_layer(vertex)
                else:
                    consumed_vertices = self._create_convolutional_layer(vertex)
            elif vertex.op in BN_OPS or is_transposed_batch_norm:
                consumed_vertices = self._create_batch_norm_layer(vertex, is_transposed_batch_norm)
            elif vertex.op in DENSE_OPS:
                if vertex.is_conv1x1_matmul() or vertex.is_conv1x1_dense():
                    consumed_vertices = self._create_convolutional_layer(vertex)
                elif vertex.is_matmul_layer():
                    self._create_matmul_layer(vertex)
                else:
                    consumed_vertices = self._create_dense_layer(vertex)
            elif vertex.op in POOL_OPS and not is_layer_norm:
                if vertex.is_grouped_reduce_max():
                    consumed_vertices = self._create_grouped_reduce_max_layer(vertex)
                elif vertex.is_ew_mean():
                    self._create_ew_mean_layer(vertex)
                elif vertex.is_reduce_mean_layer():
                    consumed_vertices = self._create_reduce_mean_layer(vertex)
                else:
                    consumed_vertices = self._create_pooling_layer(vertex)
            elif vertex.op in LOGITS_OPS or is_alternative_softmax:
                consumed_vertices = self._create_logits_layer(vertex, is_alternative_softmax=is_alternative_softmax)
            elif vertex.op in MAX_OPS and vertex.is_ew_max():
                self._create_ew_max_layer(vertex)
            elif vertex.op in MIN_OPS and vertex.is_ew_min():
                self._create_ew_min_layer(vertex)
            elif (
                vertex.op in ACTIVATION_OPS
                or (vertex.op in DIV_OPS and (vertex.is_inv_pos_activation()[0] or vertex.is_gelu_activation()[0]))
                or (vertex.op in ADD_OPS and vertex.is_hardswish_activation()[0])
                or (vertex.op in MUL_OPS and vertex.is_swish_activation()[0])
                or (vertex.op in POW_OPS and vertex.is_decimal_fraction_pow_activation()[0])
                or (vertex.op in SUB_OPS and vertex.is_hardsigmoid())
            ):
                if vertex.is_decomposed_l2_norm():
                    consumed_vertices = self._create_l2_normalization_layer(vertex)
                else:
                    consumed_vertices = self._create_activation_layer(vertex)
            elif vertex.op in EW_OPS + NEG_OPS and vertex.is_normalization():
                consumed_vertices = self._create_normalization_layer(vertex)
            elif vertex.op in ADD_OPS:
                if vertex.is_ew_add():
                    self._create_ew_add_layer(vertex)
                elif vertex.is_mul_by_2_ew_add():
                    self._create_normalization_layer(vertex)
                else:
                    self._create_bias_add_layer(vertex)
            elif vertex.op in SUB_OPS and vertex.is_ew_sub():
                self._create_ew_sub_layer(vertex)
            elif vertex.op in LAYER_NORMALIZATION_OPS or (
                vertex.op in MUL_OPS + POW_OPS + ["ReduceMean"] and is_layer_norm
            ):
                consumed_vertices = self._create_layer_normalization_layer(vertex)
            # NOTE: parsed as `(MUL and is_square()) or POW` — every POW vertex reaching this point becomes a square layer.
            elif vertex.op in MUL_OPS and vertex.is_square() or vertex.op in POW_OPS:
                consumed_vertices = self._create_square_layer(vertex)
            elif vertex.op in MUL_OPS:
                self._create_ew_mult_layer(vertex)
            elif vertex.op in CONCAT_OPS:
                consumed_vertices = self._create_concat_layer(vertex)
            elif vertex.op in SPLIT_OPS:
                if vertex.is_spatial_splitter():
                    consumed_vertices = self._create_spatial_split_layer(vertex)
                else:
                    consumed_vertices = self._create_feature_split_layer(vertex)
            elif (
                vertex.op in [*SHUFFLE_OPS, *TILE_OPS, "Flatten", "Unsqueeze"]
                and not vertex.is_successive_unsqueeze_flat_to_frame()
            ):
                # Reshape-like ops: decide which layer the surrounding pattern represents.
                is_instance_norm, is_group_norm = (
                    vertex.is_instance_normalization_reshape(),
                    vertex.is_group_norm_reshape()[0],
                )
                if vertex.op == "DepthToSpace":
                    self._create_depth_to_space_layer(vertex)
                elif vertex.is_reshape_expand_resize_nearest():
                    consumed_vertices = self._create_resize_layer(vertex)
                elif vertex.is_inner_product_matmul():
                    consumed_vertices = self._create_convolutional_layer(vertex, is_inner_product=True)
                elif vertex.is_shuffle():
                    consumed_vertices = self._create_shuffle_layer(vertex)
                elif vertex.is_flattened_global_avgpool() or is_flattened_global_maxpool:
                    consumed_vertices = self._create_pooling_layer(vertex)
                elif vertex.is_dilated_conv():
                    consumed_vertices = self._create_convolutional_layer(vertex)
                elif vertex.is_space_to_depth_transpose_reshape()[0] or vertex.is_depth_to_space_reshape_transpose()[0]:
                    consumed_vertices = self._handle_s2d_and_d2s_reshape_transpose(vertex)
                elif vertex.is_height_to_features_reshape():
                    consumed_vertices = self._create_convolutional_layer(vertex)
                elif vertex.is_gcn_block_transpose() or vertex.is_hc_transpose():
                    consumed_vertices = self._create_transpose_layer(vertex)
                elif is_group_norm or is_instance_norm:
                    consumed_vertices = self._create_layer_normalization_layer(
                        vertex, group_norm=is_group_norm, instance_norm=is_instance_norm
                    )
                elif vertex.op in TILE_OPS or vertex.is_unsqueeze_tile():
                    consumed_vertices = self._create_tile_layer(vertex)
                elif vertex.is_flatten_width_over_features_reshape():
                    consumed_vertices = self._create_space_to_depth_layer(vertex)
                elif (
                    vertex.is_f_to_w_transpose_reshape()
                    or vertex.is_flat_to_frames_reshape()
                    or vertex.is_width_features_transpose()
                    or vertex.is_reshape_before_einsum()
                    or vertex.is_spatial_flatten_reshape()
                    or vertex.is_spatial_unflatten()
                    or vertex.is_height_width_transpose()
                    or vertex.is_spatial_flatten_features_to_width()
                    or vertex.is_input_to_attention_windows_reshape()
                    or vertex.is_attention_windows_to_input_reshape()
                    or vertex.is_groups_to_spatial_flatten()
                    or vertex.is_spatial_flatten_to_groups()
                    or vertex.is_spatial_flatten_and_groups_to_features()
                    or vertex.is_partial_groups_to_spatial_flatten()
                    or vertex.is_features_to_stack()
                    or vertex.is_flatten_height_stack_reshape()
                ):
                    consumed_vertices, should_assign_vertex_to_layer = self._create_format_conversion_layer(vertex)
                elif vertex.is_unsqueeze_resize_nearest():
                    consumed_vertices = self._create_resize_layer(vertex)
                else:
                    msg = f"Failed to determine type of layer to create in node {vertex.name}"
                    raise UnsupportedShuffleLayerError(
                        msg,
                    )
            elif vertex.op in RESIZE_OPS or (vertex.op in TILE_OPS and vertex.is_torch_tile_resize_nearest()):
                consumed_vertices = self._create_resize_layer(vertex)
            elif (
                vertex.op in SLICE_OPS
                or (is_regular_gather_slice and not is_grouped_reduce_sum_gather)
                or (is_channel_shuffle_gather_slice)
            ):
                if vertex.is_space_to_depth():
                    consumed_vertices = self._create_space_to_depth_layer(vertex)
                elif vertex.is_reversed_argmax_slice():
                    consumed_vertices = self._create_logits_layer(vertex, is_reversed_argmax=True)
                else:
                    consumed_vertices = self._create_slice_layer(
                        vertex, is_channel_shuffle=is_channel_shuffle_gather_slice
                    )
            elif vertex.op in GATHER_OPS:
                # NOTE(review): a gather matching neither condition creates no layer and raises nothing, yet is still
                # marked consumed below — confirm this is intended.
                if vertex.is_channels_gather_to_conv():
                    consumed_vertices = self._create_convolutional_layer(vertex)
                elif is_grouped_reduce_sum_gather:
                    consumed_vertices = self._create_grouped_reduce_sum_layer(vertex)
            elif vertex.op in REDUCE_MAX_OPS:
                if vertex.is_torch_tile_reduce_max():
                    self._create_pooling_layer(vertex)
                else:
                    self._create_reduce_max_layer(vertex)
            elif vertex.op in REDUCE_MIN_OPS:
                self._create_reduce_min_layer(vertex)
            elif vertex.op in PAD_OPS:
                if vertex.is_grouped_channels_pad():
                    consumed_vertices = self._create_convolutional_layer(vertex)
                else:
                    consumed_vertices = self._create_external_pad_layer(vertex)
            elif vertex.op in REDUCE_SUM_OPS:
                self._create_reduce_sum_layer(vertex)
            elif vertex.op in REDUCE_L2_OPS:
                if vertex.is_l2_normalization():
                    consumed_vertices = self._create_l2_normalization_layer(vertex)
                else:
                    self._create_reduce_l2_layer(vertex)
            elif vertex.op in REDUCE_SUM_SQUARE_OPS:
                self._create_reduce_sum_square_layer(vertex)
            elif vertex.op in DIV_OPS and vertex.is_ew_div():
                self._create_ew_div_layer(vertex)
            elif vertex.op in INSTANCE_NORMALIZATION_OPS:
                consumed_vertices = self._create_layer_normalization_layer(vertex, instance_norm=True)
            elif vertex.op in EINSUM_OPS and is_einsum_conv1x1:
                consumed_vertices = self._create_einsum_conv1x1(vertex)
            # Parsed as `(Unsqueeze and successive_unsqueeze_flat_to_frame) or (Squeeze and spatial_flatten_reshape)`.
            elif (
                vertex.op == "Unsqueeze"
                and vertex.is_successive_unsqueeze_flat_to_frame()
                or vertex.op == "Squeeze"
                and vertex.is_spatial_flatten_reshape()
            ):
                consumed_vertices, should_assign_vertex_to_layer = self._create_format_conversion_layer(vertex)
            elif vertex.op in EQUAL_OPS:
                consumed_vertices = self._create_equal_layer(vertex)
            elif vertex.op in RNN_OPS:
                consumed_vertices = self._create_rnn_layer(vertex)
            elif vertex.op in LSTM_OPS:
                consumed_vertices = self._create_lstm_layer(vertex)
            elif vertex.op in GRU_OPS:
                consumed_vertices = self._create_gru_layer(vertex)
            elif vertex.op in ONE_HOT_OPS and vertex.is_supported_one_hot():
                consumed_vertices = self._create_one_hot_layer(vertex)
            elif vertex.op in LOG_SOFTMAX_OPS:
                consumed_vertices = self._create_log_softmax_layer(vertex)
            elif vertex.op in SCATTER_ND_OPS:
                consumed_vertices = self._create_scatter_nd_layer(vertex)
            else:
                msg = f"Unexpected node {vertex.name} ({vertex.op})"
                raise UnexpectedNodeError(msg)
        except (RecordableParserError, RecordableCreateLayerError) as e:
            # Recordable errors are reported and the vertex is left unconsumed.
            self._handle_recordable_parser_error(vertex, e)
            return

        consumed_vertices.append(vertex)
        self._handle_consumed_vertices(consumed_vertices, should_assign_vertex_to_layer)

    def _handle_consumed_vertices(self, consumed_vertices, should_assign_vertex_to_layer=True):
        """Register consumed vertices, then fold directly following ``Identity`` vertices into the same layers.

        After the base implementation runs, each consumed vertex that is mapped to a layer hands that layer to its
        ``Identity`` successors. Each such successor's name is added to the layer's original names, the successor
        is mapped to the layer, and it is marked as consumed.

        Args:
            consumed_vertices: Vertices consumed while creating the current layer.
            should_assign_vertex_to_layer: Forwarded to the base implementation.
        """
        super()._handle_consumed_vertices(consumed_vertices, should_assign_vertex_to_layer)

        for consumed in consumed_vertices:
            if consumed not in self._vertices_to_layers:
                continue
            owner_layer = self._vertices_to_layers[consumed]
            identity_successors = (succ for succ in self.graph.successors(consumed) if succ.op == "Identity")
            for identity in identity_successors:
                owner_layer.add_original_name(identity.name)
                self._vertices_to_layers[identity] = owner_layer
                self._visited_states[identity] = VertexState.CONSUMED

    def _consume_flatten_chain(self, pred, layer):
        """Consume the flatten chain of ``pred`` into ``layer`` when ``pred`` is a ``Flatten`` vertex.

        Args:
            pred: Vertex to inspect; anything other than ``Flatten`` is ignored.
            layer: Layer that absorbs each vertex of the flatten chain.
        """
        if pred.op != "Flatten":
            return
        for flatten_vertex in pred.get_flatten_chain():
            self._consume_pre_layer_op(flatten_vertex, layer)

    def _handle_s2d_and_d2s_reshape_transpose(self, vertex):
        """Fold a depth-to-space / space-to-depth reshape-transpose pattern into the predecessor's layer.

        The depth-to-space pattern is tried first, then the space-to-depth one. The layer built from the first
        predecessor (falling back to the current layer) becomes ``self._current_layer``. ``vertex`` and the
        pattern's vertices are recorded as its original names.

        Args:
            vertex: The reshape/transpose vertex that starts the pattern.

        Returns:
            The vertices reported by the matched pattern check.
        """
        matched_d2s, pattern_vertices = vertex.is_depth_to_space_reshape_transpose()
        if not matched_d2s:
            _, pattern_vertices = vertex.is_space_to_depth_transpose_reshape()
        target_layer = self._get_pred_layer(vertex)
        self._current_layer = target_layer
        self._add_original_names(target_layer, [vertex, *pattern_vertices])
        return pattern_vertices

    @staticmethod
    def _add_original_names(layer, vertices, reverse_insertion=False):
        """Record the names of ``vertices`` as original names of ``layer``, skipping ``Constant`` vertices.

        Args:
            layer: Layer whose original names are extended.
            vertices: Vertices to record; ``None`` is a no-op.
            reverse_insertion: Forwarded to ``layer.add_original_name``.
        """
        if vertices is None:
            return
        for candidate in vertices:
            if candidate.op == "Constant":
                continue
            layer.add_original_name(candidate.name, reverse_insertion=reverse_insertion)

    def _should_skip_vertex(self, vertex):
        """Return whether ``vertex`` should be skipped rather than converted into its own layer.

        All activation pattern probes are evaluated up front, for every vertex. A vertex is skipped when it is:

        * an upsample concat;
        * a resize or empty slice;
        * a Mul that is not an element-wise mult, normalization or square mul, or one that belongs to a
          SiLU/GELU/Mish/hardswish/simple-hardswish/swish pattern or a layer norm (for Mul the returned value is
          the result of that ``or`` chain);
        * an upsample/softsign/L2-norm Div;
        * any op in ``SKIP_OPS``, except an ``Unsqueeze`` that forms a nearest-neighbour resize;
        * a Neg that is part of a PReLU activation.

        Args:
            vertex: The ONNX graph vertex to inspect.
        """
        silu_found, _ = vertex.is_silu_activation()
        gelu_found, _ = vertex.is_gelu_activation()
        mish_found, _ = vertex.is_mish_activation()
        softsign_found, _ = vertex.is_softsign_activation()
        swish_found = vertex.is_swish_activation_ew_mul()
        hardswish_found, _ = vertex.is_hardswish_activation()
        simple_hardswish_found, _ = vertex.is_simple_hardswish_activation()
        keeps_keras_resize = vertex.op == "Unsqueeze" and vertex.is_unsqueeze_resize_nearest()
        op = vertex.op

        if op in CONCAT_OPS and vertex.is_upsample_concat():
            return True

        if op in SLICE_OPS and (vertex.is_resize_slice() or vertex.is_empty_slice()):
            return True

        if op in MUL_OPS:
            if not (vertex.is_ew_mult() or vertex.is_normalization() or vertex.is_square_mul()):
                return True
            return (
                silu_found
                or gelu_found
                or mish_found
                or hardswish_found
                or simple_hardswish_found
                or swish_found
                or vertex.is_layer_norm()
            )

        if op in DIV_OPS and (vertex.is_upsample_div() or softsign_found or vertex.is_l2_norm_div()):
            return True

        if op in SKIP_OPS and not keeps_keras_resize:
            return True

        return bool(op in NEG_OPS and vertex.is_prelu_activation())

    def _get_pred_layer(self, vertex):
        """Return the layer built from ``vertex``'s first predecessor.

        Falls back to ``self._current_layer`` when ``vertex`` has no predecessors or its first predecessor has
        not been mapped to a layer.
        """
        predecessors = list(self.graph.predecessors(vertex))
        if not predecessors:
            return self._current_layer
        first_pred = predecessors[0]
        if first_pred not in self._vertices_to_layers:
            return self._current_layer
        return self._vertices_to_layers[first_pred]

    def _separate_external_paddings(self):
        """Move pending external paddings out of padded layers into dedicated ``ExternalPadLayer`` instances.

        Only layers exposing both ``padding`` and ``external_padding_value`` are considered:

        * ``external_undecided`` padding whose external padding is all zeros becomes ``valid``, and no pad layer
          is created.
        * ``external_undecided``, ``same_lower`` and ``conv3d`` paddings get explicit padding values. ``same_lower``
          computes them from the kernel/stride/input size; ``conv3d`` scales the last two entries by the
          per-disparity feature count. The layer becomes ``valid`` and an ``ExternalPadLayer`` carrying the
          layer's original names is created for it. The layer's ``external_padding_value`` is cleared.

        All new pad layers are inserted through ``self._layers_graph.insert_layers`` (presumably in front of the
        layer each one pads — confirm in ``insert_layers``).
        """
        layers_to_add = {}  # new pad layer -> layer it was split out of
        for layer in list(self._layers_graph):
            if hasattr(layer, "padding") and hasattr(layer, "external_padding_value"):
                # six entries; from the code below, [h_begin, h_end, w_begin, w_end, ...] with the last two applied
                # to the features in the conv3d case
                zero_pads = layer.external_padding_value and layer.external_padding_value == [0, 0, 0, 0, 0, 0]
                external_undecided = layer.padding == TemporaryPaddingType.external_undecided
                same_lower = layer.padding == TemporaryPaddingType.same_lower
                if external_undecided and zero_pads:
                    layer.padding = PaddingType.valid
                elif external_undecided or same_lower or layer.padding == TemporaryPaddingType.conv3d:
                    external_padding = layer.external_padding_value
                    if same_lower:
                        dilations = layer.dilations if hasattr(layer, "dilations") else [1, 1, 1, 1]

                        # same_lower padding scheme is opposite to same_tensorflow
                        end_h, begin_h, end_w, begin_w = calculate_padding(
                            PaddingType.same_tensorflow,
                            layer.kernel_height,
                            layer.kernel_width,
                            layer.stride_height,
                            layer.stride_width,
                            layer.input_height,
                            layer.input_width,
                            dilations,
                        )
                        external_padding = [begin_h, end_h, begin_w, end_w]
                    if layer.padding == TemporaryPaddingType.conv3d:
                        # the disparity padding is increasing the padding on the features by f_in/disparity
                        f_in = layer.input_shape[-1] // layer.input_disparity
                        external_padding = [
                            *external_padding[:4],
                            external_padding[4] * f_in,
                            external_padding[5] * f_in,
                        ]
                        # reads the unscaled disparity padding; must run before external_padding_value is cleared
                        layer.input_disparity += sum(layer.external_padding_value[4:])
                    layer.padding = PaddingType.valid
                    pad_layer = ExternalPadLayer.create(None, layer.input_vertex_order, padding_vals=external_padding)
                    if layer.op == LayerType.maxpool:
                        # -inf is assigned to the padding layer's constant value
                        pad_layer.padding_const_value = layer.padding_const_value
                    for original_name in layer.original_names:
                        pad_layer.add_original_name(original_name)

                    layers_to_add[pad_layer] = layer
                    layer.external_padding_value = None

        self._layers_graph.insert_layers(layers_to_add)

    def _is_model_without_params(self):
        for input_vertex in self.graph.net_input:
            for succ in self.graph.successors(input_vertex):
                input_index = list(succ._info.input).index(input_vertex.name)
                if succ.op in CONV2D_OPS:
                    if CONV2D_INPUT_ORDER[input_index] != "X":
                        return True
                elif succ.op in BN_OPS and BN_INPUT_ORDER[input_index] != "X":
                    return True
        return False

    def _validate_model_params(self):
        if self._is_model_without_params():
            raise NoParamsModelError(
                "The weights in the model are considered as inputs, did you export the model with export_params=False?",
            )

    def _has_bn_ops_in_training(self):
        for node in self.graph.vertices_by_name.values():
            # from the BatchNormalization op documenataion:
            # training_mode == 0 (False): single output (output)
            # training_mode == 1 (True): multiple outputs (output, running_mean, running_var)
            if node.op in BN_OPS and len(list(node._info.output)) > 1:
                return True
        return False

    def _create_null_layer(self, vertex):
        """Create a NullLayer for ``vertex`` and add it to the layers graph.

        The vertices returned by ``vertex.get_null_vertices()`` are registered as the
        layer's original names (in the order that call requests) before the layer is added.

        Returns:
            The vertices absorbed by the null layer.
        """
        null_layer = NullLayer.create(vertex.name, vertex.input, vertex.get_output_shapes(validate_zero_dims=True))
        absorbed, reverse_insertion = vertex.get_null_vertices()
        self._add_original_names(null_layer, absorbed, reverse_insertion)
        self._add_layer(null_layer)
        return absorbed

    def _create_convolutional_layer(self, vertex, is_inner_product=False):
        """Build a Conv2DLayer from ``vertex`` and add it to the layers graph.

        Besides ONNX ``Conv`` (regular, grouped, depthwise, or with a non-constant kernel)
        and ``ConvTranspose``, several other vertex kinds are lowered to a convolution here:
        DENSE_OPS (kernel reshaped to 1x1), ``Transpose`` (via
        ``vertex.get_dilated_conv_info()``), ``Reshape`` (an identity 1x1 conv when
        ``is_inner_product`` is set, otherwise via ``get_height_to_features_conv_info()``),
        GATHER_OPS (via ``get_channels_gather_to_conv_kernel()``) and PAD_OPS (via
        ``get_grouped_pad_to_conv_info()``).

        Args:
            vertex: the graph vertex to convert.
            is_inner_product: only consulted for ``Reshape`` vertices; selects the
                identity-conv lowering used before an inner-product matmul.

        Returns:
            list: every vertex consumed by the new layer (bias, kernel and padding
            producers plus any op-specific vertices); they are also registered as the
            layer's original names.

        Raises:
            UnsupportedModelError: for a ``Conv`` with neither a constant kernel nor a
                recoverable dynamic kernel shape.
        """
        consumed_vertices = []
        bias, consumed_bias_nodes = vertex.get_bias()
        consumed_vertices.extend(consumed_bias_nodes)
        vertex_kernel, consumed_kernel_nodes = vertex.get_kernel()
        consumed_vertices.extend(consumed_kernel_nodes)
        padding, pads, _, consumed_padding_nodes = vertex.get_vertex_padding()
        consumed_vertices.extend(consumed_padding_nodes)
        output_shapes = vertex.get_output_shapes()
        # only set for a Conv whose kernel is not a constant (see below)
        dynamic_kernel_shape = None
        op = LayerType.base_conv

        if vertex.op in ["Reshape", "Transpose", *DENSE_OPS]:
            # default to a single group and unit strides/dilations; the Transpose and
            # Reshape branches below may override strides/dilations
            groups = 1
            strides, dilations = [1, 1, 1, 1], [1, 1, 1, 1]
        else:
            # in case there is a GROUPS dim in the input format and the conv is a grouped conv,
            # the conv groups are multiplied by the size of that dim
            groups_idx = (
                vertex.input_format.index(Dims.GROUPS)
                if vertex.input_format and Dims.GROUPS in vertex.input_format and vertex.get_groups() > 1
                else None
            )
            groups = vertex.get_input_shapes(convert_to_nhwc=False)[0][groups_idx] if groups_idx is not None else 1
            groups = vertex.get_groups() * groups
            strides = vertex.get_strides()
            dilations = vertex.get_dilations()

        if vertex.op == "Conv":
            if vertex.is_conv_over_groups():
                kernel, bias = vertex.get_conv_over_groups_info()
            elif vertex_kernel is None:
                # kernel is not a constant: only its shape is known
                kernel = None
                dynamic_kernel_shape = vertex.get_dynamic_kernel_shape()
                if not dynamic_kernel_shape:
                    raise UnsupportedModelError(f"Cannot find kernel in vertex {vertex.name}")
                # one input feature per group and groups == dim 3 -> depthwise; swap the last two dims
                if dynamic_kernel_shape[2] == 1 and groups == dynamic_kernel_shape[3]:
                    groups = 1
                    op = LayerType.base_dw
                    dynamic_kernel_shape = [
                        dynamic_kernel_shape[0],
                        dynamic_kernel_shape[1],
                        dynamic_kernel_shape[3],
                        dynamic_kernel_shape[2],
                    ]
            else:
                kernel = np.transpose(vertex_kernel, [2, 3, 1, 0])  # ONNX (f_out, f_in/g, k_h, k_w) -> (k_h, k_w, f_in/g, f_out)
                # one input feature per group and one group per output feature -> depthwise conv
                if kernel.shape[2] == 1 and groups == kernel.shape[3]:
                    groups = 1
                    op = LayerType.base_dw
                    kernel = np.transpose(vertex_kernel, [2, 3, 0, 1])  # ONNX (f_out, 1, k_h, k_w) -> (k_h, k_w, f_out, 1)

        elif vertex.op == "ConvTranspose":
            # TODO: validate padding and strides
            op = LayerType.base_deconv
            padding = PaddingType.deconv
            kernel = np.transpose(vertex_kernel, [2, 3, 0, 1])  # ONNX (f_in, f_out/g, k_h, k_w) -> (k_h, k_w, f_in, f_out/g)
            if groups > 1:
                # kernel manipulation to accommodate sdk group deconv implementation:
                # (k_h, k_w, f_in, f_out/g) -> (k_h, k_w, f_in/g, g * f_out/g), output features ordered group by group
                h, w, f_in, f_out = kernel.shape
                kernel = np.reshape(kernel, [h, w, groups, int(f_in / groups), f_out])
                kernel = np.transpose(kernel, [0, 1, 3, 2, 4])
                kernel = np.reshape(kernel, [h, w, int(f_in / groups), f_out * groups])

        elif vertex.op in DENSE_OPS:
            if vertex.is_matmul_over_groups():
                kernel, bias = vertex.get_conv_over_groups_info()
            else:
                if vertex.should_transpose_kernel():
                    vertex_kernel = np.transpose(vertex_kernel)
                # dense kernel becomes a 1x1 conv kernel
                kernel = np.reshape(vertex_kernel, [1, 1, *vertex_kernel.shape])
            # rank-2 outputs (N, C) are expressed as rank 4 (N, 1, 1, C)
            if output_shapes and len(output_shapes[0]) == 2:
                output_shapes = [[output_shape[0], 1, 1, output_shape[1]] for output_shape in output_shapes]
            is_mha_dense, mha_consumed_ops = vertex.is_multi_head_attention_dense()
            if is_mha_dense:
                consumed_vertices.extend(mha_consumed_ops)

        elif vertex.op == "Transpose":
            vertex_kernel, dilations, strides, dilated_conv_consumed_vertices = vertex.get_dilated_conv_info()
            kernel = np.transpose(vertex_kernel, [2, 3, 1, 0])
            # the layer takes the output shape of the last vertex absorbed by this pattern
            output_shapes = dilated_conv_consumed_vertices[-1].get_output_shapes()
            consumed_vertices.extend(dilated_conv_consumed_vertices)

        elif vertex.op == "Reshape":
            if is_inner_product:
                # reshape before inner product matmul is converted to identity dummy conv
                reshape_consumed_vertices = []
                padding = PaddingType.valid
                input_features = vertex.get_input_shapes()[0][3]
                kernel_shape = [1, 1, input_features, input_features]
                bias = np.zeros(input_features, dtype=np.float32)
                kernel = np.reshape(np.identity(input_features, dtype=np.float32), kernel_shape)
            else:
                kernel, bias, strides, reshape_consumed_vertices = vertex.get_height_to_features_conv_info()
                padding = PaddingType.valid
                output_shapes = reshape_consumed_vertices[0].get_output_shapes()
            consumed_vertices.extend(reshape_consumed_vertices)

        elif vertex.op in GATHER_OPS:
            kernel = vertex.get_channels_gather_to_conv_kernel()

        elif vertex.op in PAD_OPS:
            kernel, bias, groups, pad_consumed_vertices = vertex.get_grouped_pad_to_conv_info()
            consumed_vertices.extend(pad_consumed_vertices)
            padding = PaddingType.valid

        layer = Conv2DLayer.create(
            vertex.name,
            vertex.input,
            op,
            kernel,
            bias,
            padding,
            pads,
            strides,
            dilations,
            groups,
            output_shapes,
            dynamic_kernel_shape,
        )

        self._add_original_names(layer, consumed_vertices)
        self._add_layer(layer)
        return consumed_vertices

    def _create_conv3d_layer(self, vertex):
        """Lower a 3D convolution ``vertex`` onto a ``base_conv`` Conv2DLayer.

        Every conv parameter, including ``input_disparity``, is taken from
        ``vertex.get_conv3d_info()``; the layer never has a dynamic kernel.

        Returns:
            The vertices consumed by the conv3d lowering.
        """
        conv3d_info, absorbed = vertex.get_conv3d_info()
        conv_kwargs = {
            "padding": conv3d_info.padding,
            "padding_vals": conv3d_info.pads_val,
            "strides": conv3d_info.strides,
            "dilations": conv3d_info.dilations,
            "groups": conv3d_info.groups,
            "output_shapes": conv3d_info.output_shapes,
            "dynamic_kernel_shape": None,
            "input_disparity": conv3d_info.input_disparity,
        }
        conv_layer = Conv2DLayer.create(
            vertex.name,
            vertex.input,
            LayerType.base_conv,
            conv3d_info.kernel,
            conv3d_info.bias,
            **conv_kwargs,
        )

        self._add_original_names(conv_layer, absorbed)
        self._add_layer(conv_layer)
        return absorbed

    def _create_batch_norm_layer(self, vertex, is_transposed_batch_norm):
        """Create a BatchNormLayer for ``vertex`` and add it to the layers graph.

        Args:
            vertex: the BatchNormalization vertex.
            is_transposed_batch_norm: when True, BN info comes from
                ``vertex.get_transposed_bn_info()`` and every consumed vertex except
                ``Transpose`` ones is registered as an original name of the layer;
                otherwise ``vertex.get_bn_info()`` is used and no names are added.

        Returns:
            The vertices consumed by the batch-norm conversion.
        """
        fetch_bn_info = vertex.get_transposed_bn_info if is_transposed_batch_norm else vertex.get_bn_info
        bn_info, consumed_vertices = fetch_bn_info()

        bn_layer = BatchNormLayer.create(vertex.name, vertex.input, bn_info)
        self._add_layer(bn_layer)

        if is_transposed_batch_norm:
            for absorbed in consumed_vertices:
                if absorbed.op == "Transpose":
                    continue
                bn_layer.add_original_name(absorbed.name)

        return consumed_vertices

    def _create_layer_normalization_layer(self, vertex, group_norm=False, instance_norm=False):
        """Create a LayerNormalizationLayer for ``vertex``.

        Args:
            vertex: the normalization vertex.
            group_norm: read info from ``vertex.get_group_norm_info()``; takes precedence.
            instance_norm: read info from ``vertex.get_instance_normalization_info()``;
                RMS-norm is always off in that case.

        With neither flag set, ``vertex.get_layer_normalization_info()`` is used.

        Returns:
            The vertices consumed by the normalization pattern.
        """
        if group_norm:
            norm_info, use_rms, absorbed = vertex.get_group_norm_info()
        elif instance_norm:
            norm_info, absorbed = vertex.get_instance_normalization_info()
            use_rms = False
        else:
            norm_info, use_rms, absorbed = vertex.get_layer_normalization_info()

        norm_layer = LayerNormalizationLayer.create(vertex.name, vertex.input, norm_info, rms_norm=use_rms)
        self._add_original_names(norm_layer, absorbed)
        self._add_layer(norm_layer)
        return absorbed

    def _create_l2_normalization_layer(self, vertex):
        """Create an L2NormalizationLayer for ``vertex`` from its axis and scale.

        Returns:
            The vertices consumed by the L2-normalization pattern.
        """
        norm_axis, norm_scale, absorbed = vertex.get_l2_normalization_info()
        l2_layer = L2NormalizationLayer.create(vertex.name, vertex.input, norm_axis, norm_scale)
        self._add_original_names(l2_layer, absorbed)
        self._add_layer(l2_layer)
        return absorbed

    def _create_dense_layer(self, vertex):
        """Create a DenseLayer for ``vertex`` and add it to the layers graph.

        Rank-2 output shapes ``(N, C)`` are passed to the layer as rank 4 ``(N, 1, 1, C)``.

        Returns:
            list: the bias-producing vertices followed by the kernel-producing vertices.
        """
        bias, bias_nodes = vertex.get_bias()
        kernel, kernel_nodes = vertex.get_kernel()
        consumed_vertices = [*bias_nodes, *kernel_nodes]
        transpose_kernel = vertex.should_transpose_kernel()

        def _as_rank4(shape):
            return shape if len(shape) == 4 else [shape[0], 1, 1, shape[1]]

        dense_layer = DenseLayer.create(
            vertex.name,
            vertex.input,
            bias,
            kernel,
            should_transpose_kernel=transpose_kernel,
            output_shapes=[_as_rank4(shape) for shape in vertex.get_output_shapes()],
        )
        # NOTE(review): only the bias nodes become original names here, while
        # _create_convolutional_layer registers both bias and kernel nodes — confirm intended.
        self._add_original_names(dense_layer, bias_nodes)
        self._add_layer(dense_layer)
        return consumed_vertices

    def _create_pooling_layer(self, vertex):
        """Build an avgpool/maxpool PoolingLayer from ``vertex`` and add it to the layers graph.

        Three forms are handled:

        * ``AveragePool`` / ``MaxPool``: kernel, strides and padding come from the node.
        * ReduceMean-style averaging recognized by ``vertex.is_avgpool_reduce_mean()``:
          kernel and strides come from ``vertex.get_avgpool_reduce_mean_info()``.
        * Global/reduction ops (GlobalAveragePool, GlobalMaxPool, Mean, ReduceMean,
          Reshape, ReduceMax, Flatten): the layer is created with
          ``should_set_kernel_to_input_shape=True`` and no explicit kernel.

        For ``Reshape``/``Flatten`` the output shapes and consumed vertices come from
        ``vertex.get_flattened_pooling_info()``, and the consumed vertices are registered
        as original names of the layer.

        Returns:
            list: the vertices consumed by this layer.

        Raises:
            UnsupportedOperationError: if ``vertex`` matches none of the forms above.
        """
        count_include_pad = True
        if vertex.op in ["AveragePool", "GlobalAveragePool", "Mean", "ReduceMean", "Reshape"]:
            op = LayerType.avgpool
            count_include_pad = vertex.get_avgpool_count_include_pad()
        else:
            op = LayerType.maxpool

        consumed_vertices = []
        should_add_original_names = False
        if vertex.op in ["AveragePool", "MaxPool"]:
            padding, pads, _, consumed_vertices = vertex.get_vertex_padding()
            strides = vertex.get_strides()
            should_set_kernel_to_input_shape = False
            dims = vertex.get_kernel_shape()
            # a 1D kernel is placed on the width axis; 2D kernels use their last two dims
            kernel_shape = [1, 1, dims[0], 1] if len(dims) == 1 else [1, dims[-2], dims[-1], 1]
        elif vertex.is_avgpool_reduce_mean():
            pads = None
            padding = PaddingType.valid
            should_set_kernel_to_input_shape = False
            kernel_shape, strides = vertex.get_avgpool_reduce_mean_info()
        elif vertex.op in [
            "GlobalAveragePool",
            "GlobalMaxPool",
            "Mean",
            "ReduceMean",
            "Reshape",
            "ReduceMax",
            "Flatten",
        ]:
            if vertex.op == "ReduceMean":
                vertex.validate_reduce_mean_as_pooling_layer()
            # no explicit kernel: the layer is asked to use the whole input shape instead
            kernel_shape = None
            strides = None
            pads = None
            padding = PaddingType.valid
            should_set_kernel_to_input_shape = True
        else:
            # Any other op would leave kernel_shape/strides/padding/pads unbound and fail
            # below with an UnboundLocalError; report it as an unsupported op instead.
            raise UnsupportedOperationError(f"Unexpected pooling op at {vertex.name}, op={vertex.op}")

        output_shapes = vertex.get_output_shapes()
        if vertex.op in ["Reshape", "Flatten"]:
            should_add_original_names = True
            output_shapes, consumed_vertices = vertex.get_flattened_pooling_info()

        layer = PoolingLayer.create(
            vertex.name,
            vertex.input,
            op,
            kernel_shape,
            strides,
            padding,
            pads,
            should_set_kernel_to_input_shape,
            output_shapes=output_shapes,
            count_include_pad=count_include_pad,
            ceil_mode=vertex.get_ceil_mode(op),
        )

        self._add_layer(layer)
        if should_add_original_names:
            self._add_original_names(layer, consumed_vertices)
        return consumed_vertices

    def _create_einsum_conv1x1(self, vertex):
        """Create an EinsumLayer for ``vertex`` from its kernel and einsum equation.

        No additional vertices are absorbed, so an empty list is registered as
        original names and returned.
        """
        einsum_kernel, einsum_equation = vertex.get_einsum_info()
        einsum_layer = EinsumLayer.create(
            vertex.name,
            vertex.input,
            einsum_equation,
            einsum_kernel,
            vertex.get_output_shapes(),
        )
        self._add_layer(einsum_layer)
        nothing_consumed = []
        self._add_original_names(einsum_layer, nothing_consumed)
        return nothing_consumed

    def _create_activation_layer(self, vertex):
        """Create an ActivationLayer for ``vertex`` and add it to the layers graph.

        The activation type is chosen from ``vertex.op`` together with pattern checks on
        the vertex. Examples: SiLU vs. sigmoid for Sigmoid, Mish vs. softplus for
        Softplus, GELU or inv_pos for Div, and hardswish/hardsigmoid for Add/HardSwish.
        Only the parameters relevant to the chosen activation are filled in; the rest
        stay ``None``.

        Returns:
            list: the vertices absorbed into the activation (empty when the activation
            is a single node such as Relu).

        Raises:
            UnsupportedActivationLayerError: if no supported activation matches the vertex.
        """
        activation = None
        consumed_vertices = []
        # forwarded to _add_original_names; only set for Clip (see below)
        reverse_insertion = False

        # activation-specific parameters; only those of the chosen activation are set
        leaky_alpha = None
        delta_bias = None
        activation_threshold = None
        prelu_slope = None
        swish_beta = None
        activation_values = None
        hardsigmoid_alpha = None
        hardsigmoid_beta = None
        clip_min = None
        clip_max = None
        pow_exponent = None

        if vertex.op == "PRelu" or vertex.is_prelu_activation():
            vertex_prelu_slope, consumed_vertices = vertex.get_prelu_slope()
            # per-channel slopes stay prelu; a single shared slope is expressed as leaky relu
            if np.any(vertex_prelu_slope != vertex_prelu_slope[0]):
                activation = ActivationType.prelu
                prelu_slope = vertex_prelu_slope
            else:
                activation = ActivationType.leaky
                leaky_alpha = vertex_prelu_slope[0]
        elif vertex.op == "Relu":
            activation = ActivationType.relu
        elif vertex.op == "Elu":
            activation = ActivationType.elu
        elif vertex.op == "Sigmoid":
            is_silu, consumed_vertices = vertex.is_silu_activation()
            activation = ActivationType.silu if is_silu else ActivationType.sigmoid
        elif vertex.op == "Exp":
            activation = ActivationType.exp
        elif vertex.op == "LeakyRelu":
            vertex_leaky_alpha = vertex.get_leaky_alpha()
            # a negative alpha is mapped to prelu with a single slope value
            if vertex_leaky_alpha < 0:
                prelu_slope = [vertex_leaky_alpha]
                activation = ActivationType.prelu
            else:
                leaky_alpha = vertex_leaky_alpha
                activation = ActivationType.leaky
        elif vertex.op == "Tanh":
            activation = ActivationType.tanh
        elif vertex.op in ["Sign", "Abs", "Softsign"]:
            is_biased_delta, delta_bias, biased_delta_consumed_vertices = vertex.is_biased_delta_activation()
            is_softsign, softsign_consumed_vertices = vertex.is_softsign_activation()
            if is_biased_delta:
                activation = ActivationType.biased_delta
                consumed_vertices = biased_delta_consumed_vertices
            elif is_softsign:
                activation = ActivationType.softsign
                consumed_vertices = softsign_consumed_vertices
            elif vertex.op == "Abs":
                # |x| is prelu with slope -1
                prelu_slope = [-1.0]
                activation = ActivationType.prelu
        elif vertex.op == "Greater":
            is_threshold_activation, activation_threshold, consumed_vertices = vertex.is_threshold_activation()
            if is_threshold_activation:
                activation = ActivationType.threshold
            else:
                activation, activation_values, consumed_vertices = vertex.get_activation_less_or_greater_values()
        elif vertex.op == "Softplus":
            is_mish, consumed_vertices = vertex.is_mish_activation()
            activation = ActivationType.mish if is_mish else ActivationType.softplus
        elif vertex.op == "Div":
            is_gelu_activation, consumed_vertices = vertex.is_gelu_activation()
            if is_gelu_activation:
                activation = ActivationType.gelu
            else:
                is_inv_pos_activation, consumed_vertices = vertex.is_inv_pos_activation()
                if is_inv_pos_activation:
                    activation = ActivationType.inv_pos
        elif vertex.op == "Reciprocal":
            activation = ActivationType.inv_pos
        elif vertex.op in ["Add", "HardSwish"] or vertex.is_hardsigmoid():
            is_simple_hardswish_activation = False
            is_hardswish_activation, consumed_vertices = vertex.is_hardswish_activation()
            if vertex.is_hardsigmoid():
                # a hardsigmoid that is part of a simple hardswish pattern becomes hardswish
                is_simple_hardswish_activation, consumed_vertices = vertex.is_simple_hardswish_activation()
                if not is_simple_hardswish_activation:
                    activation = ActivationType.hardsigmoid
                    hardsigmoid_alpha, hardsigmoid_beta, consumed_vertices = vertex.get_hardsigmoid_info()
            if is_hardswish_activation or is_simple_hardswish_activation:
                activation = ActivationType.hardswish
        elif vertex.op == "Mul":
            is_swish_activation, swish_beta, consumed_vertices = vertex.is_swish_activation()
            if is_swish_activation:
                activation = ActivationType.swish
        elif vertex.op == "Sqrt":
            activation = ActivationType.sqrt
        elif vertex.op == "Less":
            activation, activation_values, consumed_vertices = vertex.get_activation_less_or_greater_values()
        elif vertex.op == "Log":
            activation = ActivationType.log
        elif vertex.op in ("Min", "Max", "Clip"):
            activation = ActivationType.clip
            clip_min, clip_max, consumed_vertices = vertex.get_min_max_clip_info()
            reverse_insertion = vertex.op == "Clip"  # clip op requires inverted original names due to cast preds
        elif vertex.op == "Pow":
            is_fraction_pow, pow_exponent, consumed_vertices = vertex.is_decimal_fraction_pow_activation()
            if is_fraction_pow:
                activation = ActivationType.pow

        # any branch above that did not recognize its pattern leaves activation as None
        if activation is None:
            raise UnsupportedActivationLayerError(f"Unexpected activation at {vertex.name}, op={vertex.op}")

        layer = ActivationLayer.create(
            vertex.name,
            vertex.input,
            activation,
            leaky_alpha,
            delta_bias=delta_bias,
            activation_threshold=activation_threshold,
            output_shapes=vertex.get_output_shapes(),
            prelu_slope=prelu_slope,
            swish_beta=swish_beta,
            activation_values=activation_values,
            hardsigmoid_alpha=hardsigmoid_alpha,
            hardsigmoid_beta=hardsigmoid_beta,
            clip_min=clip_min,
            clip_max=clip_max,
            pow_exponent=pow_exponent,
        )

        self._add_layer(layer)
        self._add_original_names(layer, consumed_vertices, reverse_insertion)
        return consumed_vertices

    def _create_ew_add_layer(self, vertex):
        """Create an element-wise addition layer (EWAddLayer) for ``vertex``."""
        self._create_ew_layer(vertex, layer_cls=EWAddLayer)

    def _create_ew_sub_layer(self, vertex):
        """Create an element-wise subtraction layer (EWSubLayer) for ``vertex``."""
        self._create_ew_layer(vertex, layer_cls=EWSubLayer)

    def _create_ew_mult_layer(self, vertex):
        """Create an element-wise multiplication layer (EWMultLayer) for ``vertex``."""
        self._create_ew_layer(vertex, layer_cls=EWMultLayer)

    def _create_ew_div_layer(self, vertex):
        """Create an element-wise division layer (EWDivLayer) for ``vertex``."""
        self._create_ew_layer(vertex, layer_cls=EWDivLayer)

    def _create_ew_max_layer(self, vertex):
        """Create an element-wise maximum layer (EWMaxLayer) for ``vertex``."""
        self._create_ew_layer(vertex, layer_cls=EWMaxLayer)

    def _create_ew_min_layer(self, vertex):
        """Create an element-wise minimum layer (EWMinLayer) for ``vertex``."""
        self._create_ew_layer(vertex, layer_cls=EWMinLayer)

    def _create_ew_layer(self, vertex, layer_cls):
        ew_op_input = vertex.input
        is_ew_op_with_const_input = vertex.is_ew_op_with_const_input()

        const_layer = None
        if is_ew_op_with_const_input:
            const_layer, ew_op_input = self._create_ew_const_input_layer(vertex)

        layer = layer_cls.create(
            vertex.name,
            ew_op_input,
            output_shapes=vertex.get_output_shapes(validate_zero_dims=True),
        )
        self._add_layer(layer)
        if is_ew_op_with_const_input and const_layer is not None:
            self._layers_graph.add_edge(const_layer, layer)
            layer.add_input_by_vertex(const_layer, input_name=const_layer.original_names[0])

    def _create_ew_const_input_layer(self, vertex):
        """Create a ConstInputLayer holding the constant operand of element-wise ``vertex``.

        The values from ``vertex.get_const_input_values()`` are rearranged according to
        their rank and to ``vertex.output_format``. ``const_output_shape`` is derived from
        the vertex's default output shape (``ew_op_output_shape``). The new layer is
        added without an edge; the caller is expected to connect it.

        Returns:
            tuple: ``(const_layer, ew_op_input)``, where ``ew_op_input`` is
            ``vertex.get_const_layer_input_order()``.

        Raises:
            UnsupportedConstInputError: if the constant's rank is not handled (rank 0,
                rank 1 without an output format, or rank above 5).
        """
        ew_op_input = vertex.get_const_layer_input_order()
        input_values = vertex.get_const_input_values()
        input_values_orig_shape = input_values.shape
        # default output shape (presumably NHWC-converted — confirm) vs. the raw ONNX output shape
        ew_op_output_shape = vertex.get_output_shapes()[0]
        onnx_ew_op_output_shape = vertex.get_output_shapes(convert_to_nhwc=False)[0]

        # NOTE(review): several branches below do not assign const_output_shape and rely on the
        # fallback after the if/elif chain (output_format set and ew_op_output_shape[0] == 1);
        # otherwise ConstInputLayer.create would hit an UnboundLocalError — confirm unreachable.
        if len(input_values_orig_shape) == 1 and vertex.output_format:
            # rank-1 constant: lift to rank 3, values on axis 0 when the output has a HEIGHT
            # dim, otherwise on axis 1
            input_values_expended_shape = [1, 1, 1]
            if Dims.HEIGHT in vertex.output_format:
                input_values_expended_shape[0] = input_values_orig_shape[0]
            else:
                input_values_expended_shape[1] = input_values_orig_shape[0]
            input_values = np.reshape(input_values, input_values_expended_shape)
        elif len(input_values_orig_shape) == 2:
            if len(onnx_ew_op_output_shape) == 3:
                if vertex.output_format == [Dims.BATCH, Dims.CHANNELS, Dims.WIDTH]:
                    # swap the two axes and add a leading unit axis
                    input_values = np.expand_dims(np.transpose(input_values, [1, 0]), axis=0)
                else:
                    const_output_shape = [-1, ew_op_output_shape[1], ew_op_output_shape[3], ew_op_output_shape[2]]
                    input_values = np.reshape(input_values, [1, *ew_op_output_shape[2:]])
            else:
                const_output_shape = [-1, *ew_op_output_shape[1:4]]
                # append a trailing unit axis: (a, b) -> (a, b, 1)
                input_values = np.reshape(input_values, [input_values.shape[0], input_values.shape[1], 1])
        elif len(input_values_orig_shape) == 3:
            if vertex.output_format == [Dims.BATCH, Dims.WIDTH, Dims.CHANNELS] or (
                vertex.output_format == [Dims.GROUPS, Dims.WIDTH, Dims.CHANNELS] and input_values_orig_shape[0] == 1
            ):
                # already in the expected layout
                pass
            elif vertex.output_format in [
                [Dims.WIDTH, Dims.BATCH, Dims.CHANNELS],
                [Dims.WIDTH, Dims.GROUPS, Dims.CHANNELS],
            ]:
                const_output_shape = [-1, ew_op_output_shape[2], ew_op_output_shape[1], ew_op_output_shape[3]]
                # swap the first two axes
                input_values = np.transpose(input_values, [1, 0, 2])
            elif vertex.output_format == [Dims.BATCH, Dims.CHANNELS, Dims.WIDTH]:
                # swap the last two axes
                input_values = np.transpose(input_values, [0, 2, 1])
            elif vertex.output_format == [Dims.BATCH, Dims.GROUPS, Dims.WIDTH, Dims.CHANNELS]:
                const_output_shape = [-1, *ew_op_output_shape[1:4]]
                input_values = np.transpose(input_values, [1, 0, 2])
                # merge the last two axes: (a, b, c) -> (1, a, b * c)
                input_values = np.reshape(
                    input_values,
                    [1, input_values.shape[0], input_values.shape[1] * input_values.shape[2]],
                )
            elif vertex.output_format == [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]:
                # move axis 0 to the end
                input_values = np.transpose(input_values, [1, 2, 0])
            elif vertex.output_format == [Dims.GROUPS, Dims.WIDTH, Dims.CHANNELS]:
                input_values = np.transpose(input_values, [1, 0, 2])
                # flatten the groups dimension over the channels dimension
                input_values = np.reshape(
                    input_values,
                    [1, input_values.shape[0], input_values.shape[1] * input_values.shape[2]],
                )
            elif vertex.output_format in [
                [Dims.GROUPS, Dims.CHANNELS, Dims.WIDTH],
                [Dims.BATCH, Dims.GROUPS, Dims.CHANNELS, Dims.WIDTH],
                [Dims.HEIGHT, Dims.GROUPS, Dims.WIDTH, Dims.CHANNELS],
            ]:
                input_values = np.transpose(input_values, [2, 0, 1])
                # reshape to [1, width, groups * channels]
                input_values = np.reshape(
                    input_values,
                    [1, input_values.shape[0], input_values.shape[1] * input_values.shape[2]],
                )
            else:
                const_output_shape = [-1, ew_op_output_shape[1], ew_op_output_shape[3], ew_op_output_shape[2]]
        elif len(input_values_orig_shape) == 4:
            const_output_shape = (
                [-1, *ew_op_output_shape[1:4]]
                if vertex.output_format is None or vertex.output_format[0] == Dims.BATCH
                else ew_op_output_shape
            )
            if vertex.output_format and vertex.output_format[1:] == [Dims.WIDTH, Dims.GROUPS, Dims.CHANNELS]:
                # drop the leading axis, merge groups into channels, then restore a leading unit axis
                input_values = np.squeeze(input_values, axis=0)
                input_values = np.reshape(
                    input_values,
                    [input_values.shape[0], input_values.shape[1] * input_values.shape[2]],
                )
                input_values = np.expand_dims(input_values, axis=0)
            elif vertex.output_format and vertex.output_format[1:] == [Dims.GROUPS, Dims.WIDTH, Dims.CHANNELS]:
                if input_values.shape[1] != 1:
                    # move groups next to channels and merge them
                    input_values = np.transpose(input_values, [0, 2, 1, 3])
                    input_values = np.reshape(
                        input_values,
                        [input_values_orig_shape[0], 1, input_values_orig_shape[2], -1],
                    )
                input_values = np.squeeze(input_values, axis=0)
            elif (
                vertex.output_format == [Dims.BATCH, Dims.HEIGHT, Dims.WIDTH, Dims.CHANNELS]
                and len(input_values.shape) == 4
            ):
                # the input values are in NHWC format no need to transpose or reshape
                input_values = input_values[0]
            elif not vertex.output_format or Dims.CHANNELS in vertex.output_format:
                input_values = np.transpose(input_values, [0, 2, 3, 1])  # move to NHWC format
                input_values = np.reshape(input_values, input_values.shape[1:4])
        elif len(input_values_orig_shape) == 5:
            if (
                vertex.output_format == [Dims.BATCH, Dims.CHANNELS, Dims.GROUPS, Dims.HEIGHT, Dims.WIDTH]
                and onnx_ew_op_output_shape[1] == 1
            ):
                const_output_shape = [-1, *ew_op_output_shape[1:4]]
                input_values = np.squeeze(input_values)
                input_values = np.transpose(input_values, [1, 2, 0])  # move to HWD format
            elif (
                vertex.output_format == [Dims.BATCH, Dims.HEIGHT, Dims.GROUPS, Dims.WIDTH, Dims.CHANNELS]
                and onnx_ew_op_output_shape[0] == 1  # batch is 1
                and onnx_ew_op_output_shape[vertex.output_format.index(Dims.GROUPS)] != 1  # there is groups
                and input_values.shape[vertex.output_format.index(Dims.GROUPS)] == 1
            ):
                # input values should be broadcasted over groups
                groups_index = vertex.output_format.index(Dims.GROUPS)
                groups = onnx_ew_op_output_shape[groups_index]
                input_values = np.concatenate([input_values] * groups, axis=groups_index)
                # move groups next to channels, drop batch and merge groups into channels
                input_values = np.transpose(input_values, [0, 1, 3, 2, 4])
                input_values = np.reshape(input_values, [*input_values.shape[1:3], -1])
                const_output_shape = [-1, *ew_op_output_shape[1:4]]

        else:
            msg = f"Unsupported const input tensor at {vertex.name}, with shape: {input_values_orig_shape}"
            raise UnsupportedConstInputError(msg)

        # for batch-1 outputs with a known format this overrides any const_output_shape set above
        if vertex.output_format and ew_op_output_shape[0] == 1:
            const_output_shape = [-1, *ew_op_output_shape[1:4]]

        const_layer = ConstInputLayer.create(f"{vertex.name}_input", [const_output_shape], input_values)
        self._add_layer(const_layer, has_edge=False)
        return const_layer, ew_op_input

    def _get_unsqueeze_tile_information(self, vertex):
        """Build the Conv2DInfo that lowers ``vertex`` and its Tile successor to a 1x1 conv.

        The tile ratio is the last dim of the successor's second input shape. The kernel
        is built from an identity matrix of size ``ratio * C``, where ``C`` is dim 1 of
        the successor's first input shape as returned by ``get_input_shapes(False)``.
        ``ratio`` copies of that identity are concatenated along the output axis, and the
        result is reshaped to ``(1, 1, ratio * C, ratio * ratio * C)``.

        Returns:
            Conv2DInfo: no bias, valid padding, unit strides/dilations, groups=1,
            output_shapes=None.
        """
        succ = next(iter(self.graph.successors(vertex)))
        ratio_to_tile = succ.get_input_shapes()[1][-1]
        strides, dilations = [1, 1, 1, 1], [1, 1, 1, 1]
        # build the identity once and reuse it for every repeat instead of rebuilding it each time
        identity = np.identity(ratio_to_tile * succ.get_input_shapes(False)[0][1])
        kernel = np.concatenate([identity for _ in range(ratio_to_tile)], axis=1)

        kernel = np.reshape(kernel, [1, 1, kernel.shape[0], kernel.shape[1]])
        return Conv2DInfo(
            kernel=kernel,
            bias=None,
            padding=PaddingType.valid,
            pads_val=[0, 0, 0, 0],
            strides=strides,
            dilations=dilations,
            output_shapes=None,
            groups=1,
        )

    def _create_tile_layer(self, vertex):
        """Convert a Tile ``vertex`` into a layer and add it to the layers graph.

        When ``vertex.is_unsqueeze_tile()`` is true, the tile becomes a 1x1 Conv2DLayer
        built from ``_get_unsqueeze_tile_information``. Otherwise it becomes a
        ConcatLayer fed with ``repeats`` copies of the input along the single tiled axis.
        That axis is mapped through ``vertex.output_format`` (NHWC when unset).

        Returns:
            list: the vertices consumed by this layer.

        Raises:
            UnsupportedTileLayerError: if the repeats are missing, more than one axis
                is tiled, or no axis is tiled at all.
        """
        if vertex.is_unsqueeze_tile():
            consumed_vertices = []
            info = self._get_unsqueeze_tile_information(vertex)
            layer = Conv2DLayer.create(
                vertex.name,
                vertex.input,
                LayerType.base_conv,
                info.kernel,
                info.bias,
                info.padding,
                info.pads_val,
                info.strides,
                info.dilations,
                info.groups,
                info.output_shapes,
            )

            self._add_original_names(layer, consumed_vertices)
            self._add_layer(layer)
            return consumed_vertices

        repeats, consumed_vertices = vertex.get_tile_repeats()
        if repeats is None:
            msg = "Tile layer with no repeats is not supported"
            raise UnsupportedTileLayerError(msg)
        tiled_axes = [(idx, count) for idx, count in enumerate(repeats) if count != 1]
        if len(tiled_axes) > 1:
            raise UnsupportedTileLayerError("Tile layer with multi-axes-tiling is not supported")
        if not tiled_axes:
            # every repeat is 1; without this check tiled_axes[0] would raise an IndexError
            raise UnsupportedTileLayerError("Tile layer without any tiled axis (all repeats are 1) is not supported")

        tile_axis, tile_count = tiled_axes[0]
        output_format = vertex.output_format or [Dims.BATCH, Dims.HEIGHT, Dims.WIDTH, Dims.CHANNELS]
        concat_dim = output_format[tile_axis]
        axis, group_sizes = vertex.get_concat_info_from_output_format(concat_dim, output_format)
        layer = ConcatLayer.create(
            vertex.name,
            # tiling n times along one axis == concatenating n copies of the input along it
            vertex.input * tile_count,
            output_shapes=vertex.get_output_shapes(validate_zero_dims=True),
            axis=axis,
            group_sizes=group_sizes,
        )
        self._add_layer(layer)
        return consumed_vertices

    def _create_concat_layer(self, vertex):
        """Create a concat layer for ``vertex``, materializing constant operands as const-input layers.

        Concatenation with a constant operand is not supported directly. Each constant operand becomes
        a new ``ConstInputLayer`` that is wired into the concat layer as an additional input. A constant
        operand is a ``Constant`` node, a graph initializer, or a value resolved during shape inference.
        Non-constant operands are passed through by name, except empty slices, which are marked consumed
        and dropped.

        Args:
            vertex: The ONNX concat vertex being translated.

        Returns:
            list: Vertices consumed while building the layer. Only the expanded class-token edge
            case populates it.

        Raises:
            UnsupportedConstInputError: If a constant operand cannot be mapped to the rank-3 layout
                used for const-input layers.
        """
        concat_input, const_layers, consumed_vertices = [], [], []
        axis = DEFAULT_CONCAT_AXIS
        var_initializers = self.graph.values_by_vertex_name[vertex.name]
        expand_vertex, new_input_consumed_vertices = vertex.get_new_concat_input()

        for i, inp_key in enumerate(vertex._info.input):
            inp_vertex = self.graph.vertices_by_inp_key.get(inp_key)
            initializer_value = var_initializers.get(inp_key)
            inference_value = self.graph.output_shapes.get(inp_key + "_value")
            # inp_vertex may be None when no vertex produces this key; fall back to the raw key.
            inp_name = inp_vertex.name if inp_vertex else inp_key

            if expand_vertex and expand_vertex == inp_vertex:
                # edge case: class token in transformer models
                new_input_output_shapes = [[-1, x[1], x[3], x[2]] for x in expand_vertex.get_output_shapes()]
                const_values = expand_vertex.get_const_input_values()
                const_layer = ConstInputLayer.create(inp_name, new_input_output_shapes, const_values)
                self._add_layer(const_layer, has_edge=False)
                self._vertices_to_layers[expand_vertex] = const_layer
                self._visited_states[expand_vertex] = VertexState.CONSUMED
                const_layers.append(const_layer)
                consumed_vertices.extend(new_input_consumed_vertices)
                axis = ConcatAxis.spatial_w
                concat_input.append(inp_name)
            else:
                if inp_vertex and inp_vertex.op == "Constant":
                    const_values = inp_vertex.parse_raw_data()
                elif initializer_value is not None:
                    const_values = initializer_value
                elif inference_value is not None:
                    const_values = inference_value
                else:
                    # Regular data input. Guard the None case (see inp_name above) instead of
                    # raising AttributeError on is_empty_slice().
                    if inp_vertex is not None and inp_vertex.is_empty_slice():
                        self._visited_states[inp_vertex] = VertexState.CONSUMED
                    else:
                        concat_input.append(inp_name)
                    continue

                # edge case: multiple const inputs with the same name
                if inp_name in concat_input:
                    inp_name += str(i)
                concat_input.append(inp_name)

                # Bring the constant to a rank-3 (H, W, C)-style array for ConstInputLayer.
                const_values_rank = len(const_values.shape)
                if const_values_rank == 4:
                    const_values = np.squeeze(np.transpose(const_values, [0, 2, 3, 1]), axis=0)
                elif vertex.output_format == [Dims.BATCH, Dims.CHANNELS, Dims.WIDTH]:
                    const_values = np.transpose(const_values, [0, 2, 1])
                elif vertex.output_format == [Dims.BATCH, Dims.CHANNELS]:
                    const_values = np.expand_dims(const_values, axis=1)
                elif const_values_rank != 3:
                    msg = f"Unsupported const input tensor at {vertex.name}, with shape: {const_values.shape}"
                    raise UnsupportedConstInputError(msg)

                # edge case: concat with initializer variable has one data input which causes incorrect axis calculation
                data_input_shape = vertex.get_input_shapes()[0]
                const_output_shape = [-1, const_values.shape[0], const_values.shape[1], const_values.shape[2]]
                output_shape = vertex.get_output_shapes()[0]
                if (data_input_shape[1] + const_output_shape[1]) == output_shape[1]:
                    axis = ConcatAxis.spatial_h
                elif const_values_rank == 3 or (
                    len(output_shape) > 2 and (data_input_shape[2] + const_output_shape[2]) == output_shape[2]
                ):
                    axis = ConcatAxis.spatial_w
                else:
                    axis = ConcatAxis.features

                const_layer = ConstInputLayer.create(inp_name, [const_output_shape], const_values)
                self._add_layer(const_layer, has_edge=False)
                const_layers.append(const_layer)

        # NOTE(review): this overwrites ``axis`` from the loop above, so the per-input axis assignments
        # currently have no effect on the created layer. Confirm get_concat_info fully supersedes them.
        axis, group_sizes, output_shapes = vertex.get_concat_info(const_layers)
        layer = ConcatLayer.create(vertex.name, concat_input, output_shapes, axis, group_sizes)
        self._add_layer(layer)

        # Const layers were added without edges; connect each one to the concat layer explicitly.
        for const_layer in const_layers:
            self._layers_graph.add_edge(const_layer, layer)
            layer.add_input_by_vertex(const_layer, input_name=const_layer.original_names[0])

        return consumed_vertices

    def _create_logits_layer(self, vertex, is_alternative_softmax=False, is_reversed_argmax=False):
        """Create a softmax or argmax layer for ``vertex``.

        Args:
            vertex: The vertex being translated.
            is_alternative_softmax: Accept an op from ``ALTERNATIVE_SOFTMAX_OPS`` as the head of a
                decomposed softmax; the layer then reads the input of the first node returned by
                ``vertex.get_softmax_nodes()``.
            is_reversed_argmax: Accept a ``SLICE_OPS`` op as part of an argmax pattern; forwarded
                to the argmax layer as ``reverse_order``.

        Returns:
            list: Vertices consumed while building the layer.

        Raises:
            UnsupportedLogitsLayerError: If the op is not an accepted logits op for the given flags,
                or no layer could be created for it.
        """
        op = vertex.op
        consumed = []
        layer = None

        if op == "Softmax":
            groups, axis, additive_mask = vertex.get_softmax_info()
            layer = SoftmaxLayer.create(
                vertex.name,
                vertex.input,
                groups=groups,
                axis=axis,
                additive_mask=additive_mask,
                output_shapes=vertex.get_output_shapes(),
            )
        elif is_alternative_softmax and op in ALTERNATIVE_SOFTMAX_OPS:
            softmax_nodes = vertex.get_softmax_nodes()
            consumed.extend(softmax_nodes)
            head_node = softmax_nodes[0]
            layer = SoftmaxLayer.create(
                vertex.name,
                head_node.input,
                output_shapes=vertex.get_output_shapes(validate_zero_dims=True),
            )
            self._vertices_with_edges.append(head_node)
        elif op in ("ArgMax", "Slice"):
            consumed.extend(vertex.get_argmax_info())
            layer = ArgmaxLayer.create(
                vertex.name,
                vertex.input,
                output_shapes=vertex.get_output_shapes(validate_zero_dims=True),
                reverse_order=is_reversed_argmax,
            )

        accepted_op = (
            op in LOGITS_OPS
            or (is_alternative_softmax and op in ALTERNATIVE_SOFTMAX_OPS)
            or (is_reversed_argmax and op in SLICE_OPS)
        )
        if not (accepted_op and layer):
            raise UnsupportedLogitsLayerError(f"Unexpected logits op at {vertex.name}, op={vertex.op}")

        self._add_layer(layer)
        self._add_original_names(layer, consumed)
        return consumed

    def _create_feature_split_layer(self, vertex):
        """Create a feature splitter layer from ``vertex.get_feature_split_info()``.

        Returns:
            list: An empty list; this translation consumes no extra vertices.
        """
        no_extra_vertices = []
        sizes, shapes, num_groups = vertex.get_feature_split_info()
        splitter = FeatureSplitterLayer.create(
            original_name=vertex.name,
            input_vertex_order=vertex.input,
            split_sizes=sizes,
            output_shapes=shapes,
            groups=num_groups,
        )

        self._add_layer(splitter)
        self._add_original_names(splitter, no_extra_vertices)
        return no_extra_vertices

    def _create_spatial_split_layer(self, vertex):
        """Create a spatial splitter layer from ``vertex.get_spatial_split_info()``.

        A split on axis 2 whose sizes are all multiples of 8 uses the dedicated width splitter,
        which compiles more efficiently. Every other split uses the generic spatial splitter.

        Returns:
            list: An empty list; this translation consumes no extra vertices.
        """
        no_extra_vertices = []
        split_axis, sizes, shapes = vertex.get_spatial_split_info()
        width_aligned = split_axis == 2 and all(size % 8 == 0 for size in sizes)

        if width_aligned:
            splitter = WidthSplitterLayer.create(vertex.name, vertex.input, sizes, output_shapes=shapes)
        else:
            splitter = SpatialSplitterLayer.create(vertex.name, vertex.input, sizes, split_axis, output_shapes=shapes)

        self._add_layer(splitter)
        self._add_original_names(splitter, no_extra_vertices)
        return no_extra_vertices

    def _create_format_conversion_layer(self, vertex):
        """Create a format-conversion (reshape/transpose) layer for ``vertex``.

        The vertex is tested against an ordered chain of ``vertex.is_*`` pattern predicates. The first
        match selects the ``FormatConversionType`` and any extra ``kwargs`` for
        ``FormatConversionLayer.create``. Several patterns also replace ``shapes`` (which defaults to
        ``vertex.get_output_shapes()``) and report extra consumed vertices.

        One case creates no layer. A width/features transpose right after a spatial flatten (and not
        part of an RNN sequence) is folded into its single already-translated predecessor's layer.

        Args:
            vertex: The reshape/transpose vertex being translated.

        Returns:
            tuple: ``(consumed_vertices, should_assign_vertex_to_layer)``. The flag is False only when
            the vertex was folded into its predecessor's layer instead of producing a new one.

        Raises:
            UnsupportedFormatConversionLayerError: If no layer was created (no pattern matched, or
                ``shapes`` ended up None) and the vertex was not folded into a predecessor.
        """
        should_assign_vertex_to_layer = True
        consumed_vertices = []
        format_conversion, layer = [None] * 2
        shapes = vertex.get_output_shapes()
        kwargs = {}

        if vertex.is_width_features_transpose():
            if vertex.is_transpose_after_spatial_flatten() and not vertex.is_rnn_sequence():
                # Fold the transpose into the predecessor's layer when exactly one translated predecessor
                # exists. Otherwise format_conversion stays None and the error below is raised.
                preds = [x for x in self.graph.predecessors(vertex) if x in self._vertices_to_layers]
                if len(preds) == 1:
                    should_assign_vertex_to_layer = False
                    pred_layer = self._vertices_to_layers[preds[0]]
                    pred_layer.add_original_name(vertex.name)
                    self._vertices_to_layers[vertex] = pred_layer
            else:
                format_conversion = FormatConversionType.transpose_width_features

        elif vertex.is_spatial_flatten_reshape():
            format_conversion = FormatConversionType.spatial_reshape
            input_shape = vertex.get_input_shapes()[0]
            is_spatial_flatten_after_group_norm, first_reshape = vertex.is_spatial_flatten_reshape_after_group_norm()
            if is_spatial_flatten_after_group_norm:
                # After a group norm, flatten relative to the input of the first reshape in that chain.
                input_shape = first_reshape.get_input_shapes()[0]

            # Collapse all dims between the first and the last into a single spatial dim.
            spatial_reshape_sizes = [1, int(np.prod(input_shape[1:-1]))]
            kwargs["spatial_reshape_sizes"] = spatial_reshape_sizes

        elif (
            vertex.is_spatial_unflatten()
            and not vertex.is_f_to_w_transpose_reshape()
            and not vertex.is_attention_windows_to_input_reshape()
            and not vertex.is_input_to_attention_windows_reshape()
        ):
            consumed_vertices, shapes, expand_sizes = vertex.get_spatial_unflatten_reshape_info()
            format_conversion = FormatConversionType.spatial_reshape
            kwargs["spatial_reshape_sizes"] = expand_sizes

        elif vertex.is_reshape_before_einsum():
            groups = vertex.get_reshape_before_einsum_info()
            format_conversion = FormatConversionType.transpose_width_features
            kwargs["groups"] = groups

        elif vertex.is_f_to_w_transpose_reshape():
            shapes, consumed_vertices = vertex.get_f_to_w_transpose_reshape_info()
            format_conversion = FormatConversionType.features_to_width_features

        elif vertex.is_flat_to_frames_reshape():
            shapes, consumed_vertices = vertex.get_flat_to_frames_reshape_info()
            format_conversion = FormatConversionType.flat_to_frames

        elif vertex.is_successive_unsqueeze_flat_to_frame():
            shapes, consumed_vertices = vertex.get_flat_to_frames_successive_unsqueeze_info()
            format_conversion = FormatConversionType.flat_to_frames

        elif vertex.is_height_width_transpose():
            format_conversion = FormatConversionType.transpose_height_width

        elif vertex.is_spatial_flatten_features_to_width():
            format_conversion = FormatConversionType.general_reshape
            shapes, consumed_vertices = vertex.get_spatial_flatten_features_to_width_info()

        elif vertex.is_input_to_attention_windows_reshape():
            format_conversion = FormatConversionType.split_windowed_attention
            shapes, consumed_vertices, kwargs["attention_params"] = vertex.get_input_to_windows_info()

        elif vertex.is_attention_windows_to_input_reshape():
            format_conversion = FormatConversionType.merge_windowed_attention
            shapes, consumed_vertices, kwargs["attention_params"] = vertex.get_windows_to_input_info()

        elif vertex.is_groups_to_spatial_flatten():
            format_conversion = FormatConversionType.groups_to_spatial_flatten
            kwargs["groups"], shapes, consumed_vertices = vertex.get_groups_to_spatial_flatten_info()

        elif vertex.is_spatial_flatten_to_groups():
            format_conversion = FormatConversionType.spatial_flatten_to_groups
            kwargs["groups"], shapes, consumed_vertices = vertex.get_spatial_flatten_to_groups_info()

        elif vertex.is_spatial_flatten_and_groups_to_features():
            # NOTE(review): reuses the spatial_flatten_to_groups conversion type for this pattern;
            # confirm this is intentional rather than a copy-paste of the branch above.
            format_conversion = FormatConversionType.spatial_flatten_to_groups
            kwargs["groups"], shapes, consumed_vertices = vertex.get_spatial_flatten_and_groups_to_features_info()

        elif vertex.is_partial_groups_to_spatial_flatten():
            format_conversion = FormatConversionType.partial_groups_to_spatial_flatten
            kwargs["groups"], shapes, consumed_vertices = vertex.get_partial_groups_to_spatial_flatten_info()
        elif vertex.is_features_to_stack():
            format_conversion = FormatConversionType.spatial_reshape
            shapes, kwargs["spatial_reshape_sizes"] = vertex.get_spatial_unflatten_features_to_groups_info()
        elif vertex.is_flatten_height_stack_reshape():
            format_conversion = FormatConversionType.spatial_reshape
            shapes, kwargs["spatial_reshape_sizes"] = vertex.get_flatten_height_stack_reshape_info()

        # A layer is created only when both a conversion type and output shapes were determined.
        if None not in (shapes, format_conversion):
            layer = FormatConversionLayer.create(vertex.name, vertex.input, format_conversion, shapes, **kwargs)
            self._add_layer(layer)
            self._add_original_names(layer, consumed_vertices)
        elif should_assign_vertex_to_layer:
            raise UnsupportedFormatConversionLayerError(f"Unable to create format conversion layer at {vertex.name}")

        return consumed_vertices, should_assign_vertex_to_layer

    def _create_depth_to_space_layer(self, vertex):
        """Create a depth-to-space layer for a DepthToSpace vertex.

        The element ordering is DCR when ``vertex.is_dcr_depth_to_space()`` is true and CRD
        otherwise. The layer is registered only if ``DepthToSpaceLayer.create`` returns a truthy
        value. Nothing is returned.
        """
        if vertex.is_dcr_depth_to_space():
            ordering = DepthToSpaceType.dcr
        else:
            ordering = DepthToSpaceType.crd

        d2s_layer = DepthToSpaceLayer.create(
            vertex.name,
            vertex.input,
            block_size=vertex.get_d2s_block_size(),
            output_shapes=vertex.get_output_shapes(),
            depth_to_space_type=ordering,
        )
        if d2s_layer:
            self._add_layer(d2s_layer)

    def _create_shuffle_layer(self, vertex):
        """Translate a shuffle-style reshape/transpose(/reshape) chain into a single Hailo layer.

        Recognized patterns:

        1. ``[WIDTH, CHANNELS]`` output format (qwen2_vl_vision edge case, ``[4x, y] -> [x, 4y]``):
           width-only classic-DCR space-to-depth with block ``[1, in_width // out_width]``.
        2. Rank-5 first reshape and ``perm=[0, 2, 1, 3, 4]``: feature shuffle with
           ``groups=first_reshape[1]``.
        3. Rank-6 first reshape and ``perm=[0, 3, 4, 1, 5, 2]``: depth-to-space with
           ``block_size=first_reshape[1]`` and no explicit ``depth_to_space_type``.
        4. Rank-6 first reshape and ``perm=[0, 1, 4, 2, 5, 3]``: CRD depth-to-space with block
           ``first_reshape[2]``.
        5. Rank-6 first reshape and ``perm=[0, 1, 3, 5, 2, 4]``: classic-CRD space-to-depth with
           block ``first_reshape[-1]`` on both spatial axes.

        Returns:
            list: Vertices consumed while building the layer.

        Raises:
            UnsupportedShuffleLayerError: If none of the patterns match.
        """
        if vertex.output_format == [Dims.WIDTH, Dims.CHANNELS]:
            in_width = vertex.get_input_shapes(convert_to_nhwc=False)[0][0]
            out_width = vertex.get_output_shapes(convert_to_nhwc=False)[0][0]
            layer = SpaceToDepthLayer.create(
                vertex.name,
                vertex.input,
                [1, in_width // out_width],
                space_to_depth_type=SpaceToDepthType.classic_dcr,
            )
            consumed_vertices = []
        else:
            first_reshape, second_reshape, perm, consumed_vertices = vertex.get_shuffle_reshape_transpose_info()
            rank = len(first_reshape)

            if rank == 5 and perm == [0, 2, 1, 3, 4]:
                layer = FeatureShuffleLayer.create(vertex.name, vertex.input, groups=first_reshape[1])
            elif rank == 6 and perm == [0, 3, 4, 1, 5, 2]:
                layer = DepthToSpaceLayer.create(
                    vertex.name,
                    vertex.input,
                    block_size=first_reshape[1],
                    first_reshape=first_reshape,
                    second_reshape=second_reshape,
                )
            elif rank == 6 and perm == [0, 1, 4, 2, 5, 3]:
                layer = DepthToSpaceLayer.create(
                    vertex.name,
                    vertex.input,
                    first_reshape[2],
                    first_reshape=first_reshape,
                    second_reshape=second_reshape,
                    depth_to_space_type=DepthToSpaceType.crd,
                )
            elif rank == 6 and perm == [0, 1, 3, 5, 2, 4]:
                block = first_reshape[-1]
                layer = SpaceToDepthLayer.create(
                    vertex.name, vertex.input, [block, block], space_to_depth_type=SpaceToDepthType.classic_crd
                )
            else:
                raise UnsupportedShuffleLayerError(f"Unable to create shuffle layer at {vertex.name}")

        self._add_layer(layer)
        self._add_original_names(layer, consumed_vertices)

        return consumed_vertices

    def _create_resize_layer(self, vertex):
        """Create a resize layer for ``vertex``, trying several parsing strategies in order.

        1. Ops in ``RESIZE_OPS``: sizes taken from shape inference (``vertex.get_output_shapes``).
        2. ``Unsqueeze``: keras-style, else torch-style, unsqueeze-based nearest-resize blocks.
        3. ``Tile``: torch tile-based resize.
        4. ``Reshape``: torch reshape/expand-based resize nearest.
        5. Fallback (only if no layer yet): dynamic sizes, then upscale factors, then constant sizes.

        Args:
            vertex: The vertex being translated.

        Returns:
            list: Vertices consumed while building the layer.

        Raises:
            UnsupportedResizeLayerError: If none of the strategies produced a layer.
        """
        layer = None
        consumed_vertices = []
        resize_method = vertex.get_resize_method()
        pixels_mode = ResizeBilinearPixelsMode.disabled

        def split_resize_sizes(output_shape):
            """Split ``output_shape`` into ``(h_sizes, w_sizes, d_sizes)``.

            The dims are named using the output format of the last consumed vertex (or of ``vertex``
            when nothing has been consumed yet). Dims missing from that format default to 1.0. With
            no format, the last two entries are taken as H and W; the third-from-last is the depth
            for rank-5 shapes, otherwise the depth is None.
            """
            out_vertex = consumed_vertices[-1] if consumed_vertices else vertex
            if out_vertex.output_format:
                dim_to_shape = dict(zip(out_vertex.output_format, output_shape))
                h_sizes = dim_to_shape.get(Dims.HEIGHT, 1.0)
                w_sizes = dim_to_shape.get(Dims.WIDTH, 1.0)
                d_sizes = dim_to_shape.get(Dims.GROUPS, 1.0)
            else:
                h_sizes, w_sizes, d_sizes = output_shape[-2], output_shape[-1], None
                if len(output_shape) == 5:
                    d_sizes = output_shape[-3]

            return h_sizes, w_sizes, d_sizes

        # Pixel mode only applies to native bilinear Resize ops; every other path keeps "disabled".
        if vertex.op == "Resize" and resize_method == ResizeMethod.bilinear:
            pixels_mode = vertex.get_resize_bilinear_pixels_mode()

        # option 1, rely on shape inference (stupid simple)
        if vertex.op in RESIZE_OPS:
            resize_sizes = vertex.get_output_shapes(convert_to_nhwc=False)
            if resize_sizes:
                consumed_upsample_nodes = vertex.consume_upsample_nodes()
                consumed_vertices.extend(consumed_upsample_nodes)
                # NOTE(review): unlike option 2, input_disparity here is computed after consumed
                # vertices were added. split_resize_sizes therefore reads the input shape using the
                # last consumed vertex's output format. Confirm this asymmetry is intended.
                h_sizes, w_sizes, d_sizes = split_resize_sizes(resize_sizes[0])
                _, _, input_disparity = split_resize_sizes(vertex.get_input_shapes(convert_to_nhwc=False)[0])
                layer = ResizeLayer.create(
                    vertex.name,
                    vertex.input,
                    resize_method,
                    input_disparity=input_disparity,
                    h_sizes=h_sizes,
                    w_sizes=w_sizes,
                    d_sizes=d_sizes,
                    resize_bilinear_pixels_mode=pixels_mode,
                )

        # option 2, converted from keras model
        elif vertex.op == "Unsqueeze":
            input_disparity = 1
            keras_nodes, resize_sizes = vertex.get_keras_unsqueeze_resize_nearest_block_info()
            if keras_nodes is None:
                # Not the keras pattern; try the torch variant. The input disparity is computed
                # before extending consumed_vertices, so it uses ``vertex``'s own output format.
                torch_nodes, resize_sizes = vertex.get_torch_unsqueeze_resize_nearest_block_info()
                _, _, input_disparity = split_resize_sizes(vertex.get_input_shapes(convert_to_nhwc=False)[0])
                consumed_vertices.extend(torch_nodes)
            else:
                consumed_vertices.extend(keras_nodes)

            h_sizes, w_sizes, d_sizes = split_resize_sizes(resize_sizes)
            layer = ResizeLayer.create(
                vertex.name,
                vertex.input,
                resize_method,
                input_disparity=input_disparity,
                h_sizes=h_sizes,
                w_sizes=w_sizes,
                d_sizes=d_sizes,
                resize_bilinear_pixels_mode=pixels_mode,
            )

        # option 3, converted from torch model
        elif vertex.op == "Tile":
            resize_sizes, consumed_resize_nodes = vertex.get_torch_tile_resize_sizes()
            consumed_vertices.extend(consumed_resize_nodes)
            h_sizes, w_sizes, d_sizes = split_resize_sizes(resize_sizes)
            layer = ResizeLayer.create(
                vertex.name,
                vertex.input,
                resize_method,
                h_sizes=h_sizes,
                w_sizes=w_sizes,
                d_sizes=d_sizes,
            )

        # option 4, another torch variant of resize nearest
        elif vertex.op == "Reshape":
            resize_sizes, consumed_resize_nodes = vertex.get_torch_reshape_expand_sizes()
            consumed_vertices.extend(consumed_resize_nodes)
            h_sizes, w_sizes, d_sizes = split_resize_sizes(resize_sizes)
            layer = ResizeLayer.create(
                vertex.name,
                vertex.input,
                resize_method,
                h_sizes=h_sizes,
                w_sizes=w_sizes,
                d_sizes=d_sizes,
            )

        # option 5, fallback to legacy parsing of layer attributes based on various implementations
        if not layer:
            resize_vertices, consumed_resize_nodes = vertex.get_resize_dynamic_sizes()
            if resize_vertices:
                # Dynamic sizes: the layer is created without sizes; the vertices that provide them
                # are recorded in self._resize_layers_meta_vertices for this layer.
                consumed_vertices.extend(consumed_resize_nodes)
                layer = ResizeLayer.create(
                    vertex.name,
                    vertex.input,
                    resize_method,
                    resize_bilinear_pixels_mode=pixels_mode,
                )
                self._resize_layers_meta_vertices[layer] = resize_vertices
            else:
                upscale_factors, consumed_upscale_nodes = vertex.get_resize_upscale_factors()
                if upscale_factors:
                    consumed_vertices.extend(consumed_upscale_nodes)
                    layer = ResizeLayer.create(
                        vertex.name,
                        vertex.input,
                        resize_method,
                        upscale_factors=upscale_factors,
                        resize_bilinear_pixels_mode=pixels_mode,
                    )
                else:
                    resize_sizes, consumed_resize_nodes = vertex.get_resize_const_sizes()
                    if resize_sizes:
                        consumed_vertices.extend(consumed_resize_nodes)
                        # For a 4-element size vector only entries [2:4] are used as the resize sizes.
                        layer_resize_sizes = resize_sizes[2:4] if len(resize_sizes) == 4 else resize_sizes
                        h_sizes, w_sizes, d_sizes = split_resize_sizes(layer_resize_sizes)
                        layer = ResizeLayer.create(
                            vertex.name,
                            vertex.input,
                            resize_method,
                            h_sizes=h_sizes,
                            w_sizes=w_sizes,
                            d_sizes=d_sizes,
                            resize_bilinear_pixels_mode=pixels_mode,
                        )

        # fallback, couldn't parse this layer
        if not layer:
            raise UnsupportedResizeLayerError(
                f"Failed to create resize layer at vertex {vertex.name}. "
                "Either upscale factor or resize shapes weren't found.",
            )

        self._add_layer(layer)
        self._add_original_names(layer, consumed_vertices)
        return consumed_vertices

    def _create_slice_layer(self, vertex, is_channel_shuffle=False):
        """Create a slice layer for ``vertex``.

        The slice parameters come from the channel-shuffle slice info when ``is_channel_shuffle``
        is set, from the gather slice info for ``GATHER_OPS``, and from the plain slice info
        otherwise.

        Returns:
            list: Vertices consumed while building the layer.
        """
        if is_channel_shuffle:
            read_slice_info = vertex.get_channel_shuffle_slice_info
        elif vertex.op in GATHER_OPS:
            read_slice_info = vertex.get_gather_slice_info
        else:
            read_slice_info = vertex.get_slice_info

        height_slice, width_slice, feature_slice, num_groups, consumed = read_slice_info()
        slice_layer = SliceLayer.create(
            vertex.name, vertex.input, height_slice, width_slice, feature_slice, groups=num_groups
        )
        self._add_layer(slice_layer)
        self._add_original_names(slice_layer, consumed)
        return consumed

    def _create_normalization_layer(self, vertex):
        """Create a normalization layer for ``vertex``.

        Mean and std come from one of two sources:

        * The multiply-by-2 element-wise add pattern, which uses ``mean=[0.0]`` and ``std=[0.5]``.
        * ``vertex.get_normalization_info()`` otherwise.

        If every std is 1 or every mean is 0, the layer type is ``NormalizationType.mul_and_add``;
        otherwise it is ``NormalizationType.normalization``. When the input has a groups dimension
        larger than 1, mean/std vectors are tiled once per group. Vectors of length 1, or of the
        (NHWC) feature length, are left unchanged.

        Args:
            vertex: The vertex being translated.

        Returns:
            list: Vertices consumed while building the layer.
        """
        if vertex.is_mul_by_2_ew_add():
            mean, std, consumed_vertices = [0.0], [0.5], []
        else:
            mean, std, consumed_vertices = vertex.get_normalization_info()

        # if std = 1 or mean = 0, normalization is actually standalone mul, sub, or add
        is_mul_add = all(x == 1 for x in std) or all(x == 0 for x in mean)
        norm_type = NormalizationType.mul_and_add if is_mul_add else NormalizationType.normalization
        activation = vertex.get_normalization_activation(consumed_vertices)

        # tiles mean and std to match the number of groups * features if needed
        output_format = vertex.output_format
        if output_format and Dims.GROUPS in output_format:
            num_groups = vertex.get_input_shapes(convert_to_nhwc=False)[0][output_format.index(Dims.GROUPS)]
            if num_groups != 1:

                def tile_per_group(values):
                    # The feature length is only read when the vector is not a scalar,
                    # matching the short-circuit of the single-expression form.
                    if len(values) == 1 or len(values) == vertex.get_input_shapes()[0][-1]:
                        return values
                    return np.tile(values, num_groups)

                std = tile_per_group(std)
                mean = tile_per_group(mean)

        layer = NormalizationLayer.create(
            vertex.name,
            vertex.input,
            mean,
            std,
            normalization_type=norm_type,
            activation=activation,
            output_shapes=vertex.get_output_shapes(),
        )

        self._add_layer(layer)
        is_shape_expand, _ = vertex.is_shape_expand_norm()
        if not is_shape_expand:
            self._add_original_names(layer, consumed_vertices)

        return consumed_vertices

    def _create_reduce_max_layer(self, vertex):
        """Create a reduce-max layer for ``vertex``.

        Validity is decided by ``vertex.is_valid_reduce_max_min()``; the group count comes from
        ``vertex.get_hailo_reduce_groups()``. Nothing is returned.

        Raises:
            UnsupportedReduceMaxLayerError: If the vertex is not a supported reduce-max.
        """
        if vertex.is_valid_reduce_max_min():
            reduce_layer = ReduceMaxLayer.create(
                vertex.name,
                vertex.input,
                groups=vertex.get_hailo_reduce_groups(),
                output_shapes=vertex.get_output_shapes(validate_zero_dims=True),
            )
            self._add_layer(reduce_layer)
            return

        raise UnsupportedReduceMaxLayerError(
            f"Failed to create reduce max layer at vertex {vertex.name}. Reduce "
            f"max is only supported on the features axis, and with keepdim=True",
        )

    def _create_grouped_reduce_max_layer(self, vertex):
        """Create a grouped reduce-max layer from ``vertex.get_grouped_reduce_max_info()``.

        Returns:
            list: Vertices consumed while building the layer.
        """
        num_groups, consumed = vertex.get_grouped_reduce_max_info()
        shapes = vertex.get_output_shapes(validate_zero_dims=True)
        reduce_layer = ReduceMaxLayer.create(vertex.name, vertex.input, groups=num_groups, output_shapes=shapes)

        self._add_layer(reduce_layer)
        self._add_original_names(reduce_layer, consumed)
        return consumed

    def _create_reduce_min_layer(self, vertex):
        """Create a reduce-min layer for ``vertex``.

        Validity is decided by ``vertex.is_valid_reduce_max_min()``; the group count comes from
        ``vertex.get_hailo_reduce_groups()``. Nothing is returned.

        Raises:
            UnsupportedReduceMinLayerError: If the vertex is not a supported reduce-min.
        """
        if vertex.is_valid_reduce_max_min():
            reduce_layer = ReduceMinLayer.create(
                vertex.name,
                vertex.input,
                groups=vertex.get_hailo_reduce_groups(),
                output_shapes=vertex.get_output_shapes(validate_zero_dims=True),
            )
            self._add_layer(reduce_layer)
            return

        raise UnsupportedReduceMinLayerError(
            f"Failed to create reduce min layer at vertex {vertex.name}. Reduce "
            f"min is only supported on the features axis, and with keepdim=True",
        )

    def _create_space_to_depth_layer(self, vertex):
        """Create a space-to-depth layer for ``vertex``.

        Two variants are handled:

        * A flatten-width-over-features reshape becomes a classic-DCR space-to-depth, with block
          sizes taken from the reshape info.
        * Otherwise the focus-style space-to-depth is used. It only supports block size 2, which is
          applied on both spatial axes.

        Returns:
            list: Vertices consumed while building the layer.

        Raises:
            UnsupportedSpaceToDepthLayerError: If the focus-style block size is not 2.
        """
        if not vertex.is_flatten_width_over_features_reshape():
            block, consumed = vertex.get_space_to_depth_info()
            if block != 2:
                raise UnsupportedSpaceToDepthLayerError(
                    f"Failed to create space to depth layer focus type at vertex "
                    f"{vertex.name}. Block size is {block}, but only block size 2 "
                    f"is supported for this layer",
                )
            sizes = [block, block]
            mode = SpaceToDepthType.focus
        else:
            sizes, consumed = vertex.get_flatten_width_over_features_reshape_info()
            mode = SpaceToDepthType.classic_dcr

        s2d_layer = SpaceToDepthLayer.create(
            vertex.name,
            vertex.input,
            sizes,
            space_to_depth_type=mode,
            output_shapes=vertex.get_output_shapes(validate_zero_dims=True),
        )
        self._add_layer(s2d_layer)
        self._add_original_names(s2d_layer, consumed)
        return consumed

    def _create_external_pad_layer(self, vertex):
        """Create an external pad layer from ``vertex.get_vertex_padding()``.

        The first element of the padding info is ignored; the pad values and constant fill value
        are forwarded to ``ExternalPadLayer.create``.

        Returns:
            list: Vertices consumed while building the layer.
        """
        _, pad_values, fill_value, consumed = vertex.get_vertex_padding()
        shapes = vertex.get_output_shapes(validate_zero_dims=True)
        pad_layer = ExternalPadLayer.create(
            vertex.name,
            vertex.input,
            padding_vals=pad_values,
            padding_const_value=fill_value,
            output_shapes=shapes,
        )
        self._add_layer(pad_layer)
        self._add_original_names(pad_layer, consumed)
        return consumed

    def _create_reduce_sum_layer(self, vertex):
        """Create a reduce-sum layer for ``vertex``.

        Interleaved-group reductions are accepted only when the vertex's first predecessor is
        already translated to a concat layer created without explicit ``group_sizes``. Nothing is
        returned.

        Raises:
            UnsupportedReduceSumLayerError: If ``vertex.get_reduce_sum_info()`` reports the reduction
                as invalid, or the interleaved-group precondition does not hold. This includes a
                vertex with no predecessor.
        """
        valid, axes, groups, interleaved_groups = vertex.get_reduce_sum_info()
        if interleaved_groups:
            # Default to None so a missing predecessor is reported through the error below,
            # instead of leaking a bare StopIteration.
            pred_vertex = next(iter(self.graph.predecessors(vertex)), None)
            pred_layer = self._vertices_to_layers.get(pred_vertex) if pred_vertex is not None else None
            if pred_layer is None or pred_layer.op != LayerType.concat or pred_layer.group_sizes is not None:
                valid = False

        if not valid:
            raise UnsupportedReduceSumLayerError(f"Failed to create reduce sum layer at vertex {vertex.name}.")

        layer = ReduceSumLayer.create(
            vertex.name,
            vertex.input,
            axes,
            groups=groups,
            interleaved_groups=interleaved_groups,
            output_shapes=vertex.get_output_shapes(validate_zero_dims=True),
        )
        self._add_layer(layer)

    def _create_grouped_reduce_sum_layer(self, vertex):
        """Create a height-grouped reduce-sum layer from ``vertex.get_grouped_reduce_sum_info()``.

        Returns:
            list: Vertices consumed while building the layer.
        """
        height_groups, reduce_axes, consumed = vertex.get_grouped_reduce_sum_info()
        self._add_layer(
            ReduceSumLayer.create(vertex.name, vertex.input, reduce_axes=reduce_axes, height_groups=height_groups)
        )
        # Unlike most creators here, original names are attached through ``self._current_layer`` and
        # include ``vertex`` itself. NOTE(review): presumably ``_current_layer`` is the layer just
        # added by ``_add_layer`` — confirm.
        self._add_original_names(self._current_layer, [vertex, *consumed])
        return consumed

    def _create_reduce_l2_layer(self, vertex):
        """Create a reduce-L2 layer for ``vertex``, using the reduce-sum info for validity and axes.

        Interleaved-group reductions are rejected. Nothing is returned.

        Raises:
            UnsupportedReduceL2LayerError: If the reduction is invalid or uses interleaved groups.
        """
        is_valid, reduce_axes, num_groups, interleaved = vertex.get_reduce_sum_info()
        if is_valid and not interleaved:
            l2_layer = ReduceL2Layer.create(
                vertex.name,
                vertex.input,
                reduce_axes,
                groups=num_groups,
                output_shapes=vertex.get_output_shapes(validate_zero_dims=True),
            )
            self._add_layer(l2_layer)
            return

        raise UnsupportedReduceL2LayerError(
            f"Failed to create reduce l2 layer at vertex {vertex.name}. "
            f"Reduce l2 is only supported on the height, width or features axis, and with keepdim=True",
        )

    def _create_reduce_sum_square_layer(self, vertex):
        """Create a reduce-sum-square layer for ``vertex``, using the reduce-sum info for validity.

        Only ungrouped (``groups == 1``), non-interleaved reductions are supported. Nothing is
        returned.

        Raises:
            UnsupportedReduceSumSquareLayerError: If the reduction is invalid, grouped, or
                interleaved.
        """
        valid, axes, groups, interleaved_groups = vertex.get_reduce_sum_info()
        if not valid or groups != 1 or interleaved_groups:
            # The message previously said "reduce l2" (copy-paste from the L2 creator) and did not
            # mention the groups == 1 requirement enforced above.
            raise UnsupportedReduceSumSquareLayerError(
                f"Failed to create reduce sum square layer at vertex {vertex.name}. "
                "Reduce Sum Square is only supported on the height, width or features axis, "
                "without groups, and with keepdim=True",
            )

        layer = ReduceSumSquareLayer.create(
            vertex.name,
            vertex.input,
            axes,
            groups=groups,
            output_shapes=vertex.get_output_shapes(validate_zero_dims=True),
        )
        self._add_layer(layer)

    def _create_bias_add_layer(self, vertex):
        """Create a bias-add layer using the bias returned by ``vertex.get_bias()``.

        Returns:
            list: Vertices consumed while extracting the bias.
        """
        bias_values, consumed = vertex.get_bias()
        self._add_layer(
            BiasAddLayer.create(
                vertex.name,
                vertex.input,
                bias_values,
                output_shapes=vertex.get_output_shapes(validate_zero_dims=True),
            )
        )
        return consumed

    def _create_square_layer(self, vertex):
        """Create a square ``FeatureMultiplierLayer`` for ``vertex``.

        A ``Pow`` vertex is accepted only when its exponent is exactly 2.0.

        Returns:
            The vertices consumed while reading the exponent (empty for non-``Pow`` vertices).

        Raises:
            UnsupportedSquareLayerError: for a ``Pow`` vertex with any other exponent.
        """
        if vertex.op != "Pow":
            consumed_vertices = []
        else:
            power, consumed_vertices = vertex.get_power()
            if power != 2.0:
                raise UnsupportedSquareLayerError(
                    f"Pow operator {vertex.name} can only be supported as square (got power of {power})",
                )

        output_shapes = vertex.get_output_shapes(validate_zero_dims=True)
        square = FeatureMultiplierLayer.create(
            vertex.name,
            vertex.input,
            feature_multiplier_type=FeatureMultiplierType.square,
            output_shapes=output_shapes,
        )
        self._add_layer(square)
        return consumed_vertices

    def _create_matmul_layer(self, vertex):
        """Create a ``MatmulLayer`` for ``vertex``, materializing a constant kernel as its own input layer.

        When ``vertex.get_matmul_layer_info()`` returns a kernel, the kernel is rearranged into a rank-3
        ``ConstInputLayer`` named ``<vertex.name>_input``. That layer is added without an edge, and is connected
        to the matmul as an input only after the matmul itself has been registered.
        """
        groups, should_transpose_input, kernel, matmul_input, _ = vertex.get_matmul_layer_info()

        const_layer = None
        if kernel is not None:
            # The kernel arrives as rank 4 (kernels of matmul/conv/dense share one parsing path); swap axes 1 and 2,
            swapped_kernel = np.transpose(kernel, [0, 2, 1, 3])
            # then fold the last two axes together to get back to the rank-3 layout of const_input layers.
            const_kernel = np.reshape(
                swapped_kernel,
                [1, swapped_kernel.shape[1], swapped_kernel.shape[2] * swapped_kernel.shape[3]],
            )
            const_output_shape = [-1, *const_kernel.shape]
            matmul_input = vertex.get_const_layer_input_order()
            const_layer = ConstInputLayer.create(f"{vertex.name}_input", [const_output_shape], const_kernel)
            self._add_layer(const_layer, has_edge=False)

        matmul = MatmulLayer.create(
            vertex.name,
            matmul_input,
            should_transpose_input=should_transpose_input,
            groups=groups,
            output_shapes=vertex.get_output_shapes(validate_zero_dims=True),
        )
        self._add_layer(matmul)

        if const_layer is not None:
            self._layers_graph.add_edge(const_layer, matmul)
            matmul.add_input_by_vertex(const_layer, input_name=const_layer.original_names[0])

    def _create_transpose_layer(self, vertex):
        """Create a ``TransposeLayer`` whose ONNX (channels-first) permutation is rewritten for channels-last.

        Returns:
            An empty list; no additional vertices are consumed.
        """
        # Where each NCHW axis lives once the tensor is laid out as NHWC.
        nchw_axis_to_nhwc = {0: 0, 1: 3, 2: 1, 3: 2}
        remapped = [nchw_axis_to_nhwc[axis] for axis in vertex.get_transpose_perm()]
        # Reorder the permutation entries themselves from NCHW positions to NHWC positions.
        nhwc_perm = [remapped[position] for position in (0, 2, 3, 1)]

        output_shapes = vertex.get_output_shapes(validate_zero_dims=True)
        self._add_layer(TransposeLayer.create(vertex.name, vertex.input, nhwc_perm, output_shapes=output_shapes))
        return []

    def _create_ew_mean_layer(self, vertex):
        """Create and register an element-wise mean layer for ``vertex``.

        Returns:
            An empty list; no additional vertices are consumed.
        """
        output_shapes = vertex.get_output_shapes(validate_zero_dims=True)
        self._add_layer(EWMeanLayer.create(vertex.name, vertex.input, output_shapes=output_shapes))
        return []

    def _create_reduce_mean_layer(self, vertex):
        """Create a ``ReduceMeanLayer`` for ``vertex``.

        The names of the vertices folded into the reduction are recorded on the layer as extra original names
        before it is registered.

        Returns:
            The consumed vertices reported by ``vertex.get_reduce_mean_info()``.
        """
        reduce_axes, num_groups, consumed_vertices = vertex.get_reduce_mean_info()
        reduce_mean = ReduceMeanLayer.create(
            vertex.name,
            vertex.input,
            reduce_axes,
            groups=num_groups,
            output_shapes=vertex.get_output_shapes(validate_zero_dims=True),
        )

        for absorbed in consumed_vertices:
            reduce_mean.add_original_name(absorbed.name)
        self._add_layer(reduce_mean)
        return consumed_vertices

    def _create_equal_layer(self, vertex):
        """Create an ``EqualLayer`` comparing the data input with a second input or a constant.

        A rank-3 constant has its leading axis moved to the end before being handed to the layer.

        Returns:
            The vertices consumed while extracting the constant.

        Raises:
            UnsupportedEqualLayerError: if there is no constant and only a single data input.
        """
        constant_input, consumed_vertices, _ = vertex.get_normalization_input_raw_values()
        if constant_input is not None:
            if len(constant_input.shape) == 3:
                # presumably (C, H, W) -> (H, W, C) to match the channels-last layout; confirm
                constant_input = constant_input.transpose([1, 2, 0])
        elif len(vertex.input) == 1:
            raise UnsupportedEqualLayerError(
                f"Unable to find equal inputs for {vertex.name} - should have data input and either data input or "
                f"value",
            )

        output_shapes = vertex.get_output_shapes(validate_zero_dims=True)
        self._add_layer(EqualLayer.create(vertex.name, vertex.input, constant_input, output_shapes=output_shapes))
        return consumed_vertices

    def _create_rnn_layer(self, vertex):
        """Create an ``RNNLayer`` from the weights and initial state returned by ``vertex.get_rnn_info()``.

        Returns:
            The consumed vertices. They are passed, together with ``vertex``, to ``_add_original_names``
            for the newly added layer.
        """
        rnn_info = vertex.get_rnn_info()
        kernel, recurrent_kernel, bias, recurrent_bias, initial_h, consumed_vertices = rnn_info
        output_shapes = vertex.get_output_shapes(validate_zero_dims=True)

        # note: RNNLayer.create takes the bias before the recurrent kernel
        rnn_layer = RNNLayer.create(
            vertex.name,
            vertex.input,
            kernel,
            bias,
            recurrent_kernel,
            recurrent_bias,
            initial_h,
            output_shapes=output_shapes,
        )
        self._add_layer(rnn_layer)
        self._add_original_names(self._current_layer, [vertex] + list(consumed_vertices))
        return consumed_vertices

    def _create_lstm_layer(self, vertex):
        """Create an ``LSTMLayer`` from the parameters returned by ``vertex.get_lstm_info()``.

        Returns:
            The consumed vertices. They are passed, together with ``vertex``, to ``_add_original_names``
            for the newly added layer.
        """
        forward_params, backward_params, direction, consumed_vertices = vertex.get_lstm_info()
        output_shapes = vertex.get_output_shapes(validate_zero_dims=True)
        lstm_layer = LSTMLayer.create(
            vertex.name,
            vertex.input,
            forward_params,
            backward_params,
            direction,
            output_shapes=output_shapes,
        )
        self._add_layer(lstm_layer)
        self._add_original_names(self._current_layer, [vertex] + list(consumed_vertices))
        return consumed_vertices

    def _create_gru_layer(self, vertex):
        """Create a ``GRULayer`` from the configuration returned by ``vertex.get_gru_info()``.

        Returns:
            The consumed vertices. They are passed, together with ``vertex``, to ``_add_original_names``
            for the newly added layer.
        """
        gru_info = vertex.get_gru_info()
        (
            hidden_size,
            linear_before_reset,
            kernel,
            recurrence_kernel,
            bias,
            sequence_lens,
            initial_h,
            consumed_vertices,
        ) = gru_info
        output_shapes = vertex.get_output_shapes(validate_zero_dims=True)

        # GRULayer.create receives the output shapes positionally, as the last argument
        gru_layer = GRULayer.create(
            vertex.name,
            vertex.input,
            hidden_size,
            linear_before_reset,
            kernel,
            recurrence_kernel,
            bias,
            sequence_lens,
            initial_h,
            output_shapes,
        )
        self._add_layer(gru_layer)
        self._add_original_names(self._current_layer, [vertex] + list(consumed_vertices))
        return consumed_vertices

    def _create_log_softmax_layer(self, vertex):
        """Create a ``LogSoftmaxLayer`` over the axis reported by ``vertex.get_log_softmax_axis()``.

        Returns:
            An empty list; no additional vertices are consumed.
        """
        softmax_axis = vertex.get_log_softmax_axis()
        output_shapes = vertex.get_output_shapes(validate_zero_dims=True)
        self._add_layer(LogSoftmaxLayer.create(vertex.name, vertex.input, softmax_axis, output_shapes))
        self._add_original_names(self._current_layer, [vertex])
        return []

    def _create_one_hot_layer(self, vertex):
        """Create a ``OneHotLayer`` sized by the last dimension of the first output shape.

        Returns:
            The vertices consumed by ``vertex.get_one_hot_info()``.
        """
        consumed_vertices, out_shapes = vertex.get_one_hot_info()
        last_dim = out_shapes[0][-1]
        one_hot = OneHotLayer.create(vertex.name, vertex.input, last_dim, out_shapes)
        # original names are attached to the layer object before it is registered
        self._add_original_names(one_hot, consumed_vertices)
        self._add_layer(one_hot)
        return consumed_vertices

    def _create_scatter_nd_layer(self, vertex):
        """Validate ``vertex`` and create a ``ScatterNDLayer`` carrying its constant input values.

        Returns:
            An empty list; no additional vertices are consumed.
        """
        no_consumed = []
        vertex.validate_scatter_nd_info()
        output_shapes = vertex.get_output_shapes(validate_zero_dims=True)
        const_values = vertex.get_const_input_values()

        scatter = ScatterNDLayer.create(vertex.name, vertex.input, const_values, output_shapes)
        self._add_original_names(scatter, no_consumed)
        self._add_layer(scatter)
        return no_consumed

    def _handle_multi_output_copies_split(self):
        """
        The goal of this function is to solve edge cases where multiple outputs of splitters go to different layers.

        The function looks at every feature/spatial/width splitter layer whose original vertex has a successor with
        more than one input taken from the splitter's outputs. For such a splitter, an output is routed through a
        new ``ShortcutLayer`` when either of these holds:

        - it feeds several successors, or
        - one of its successors also consumes a later output that itself feeds at most one successor.

        The successors are rewired to the shortcut, and the splitter's ``split_indices``, ``outputs``,
        ``output_shapes`` and ``output_indices`` are rewritten to match. The layers graph is modified in place.
        """
        for layer in list(self._layers_graph):
            # only splitter layers are handled here
            if layer.op not in [LayerType.feature_splitter, LayerType.spatial_splitter, LayerType.width_splitter]:
                continue

            if not layer.original_names:
                continue

            # map original output names to vertices in the nn graph
            # (outputs and successor inputs are matched on the text after the last ":")
            orig_vertex = self.graph.vertices_by_name[layer.original_names[0]]
            orig_output_to_orig_successors = {}
            for orig_output in list(orig_vertex.output):
                orig_succs = [
                    x
                    for x in self.graph.successors(orig_vertex)
                    if any(orig_output.split(":")[-1] == y.split(":")[-1] for y in x.input)
                ]
                orig_output_to_orig_successors[orig_output] = orig_succs

            # there exists a successor that has multiple inputs from copies of the splitter output
            multi_inputs_contained_cond = False
            for succ in self.graph.successors(orig_vertex):
                relevant_inputs = [
                    x for x in succ.input if any(y.split(":")[-1] == x.split(":")[-1] for y in orig_vertex.output)
                ]
                if len(relevant_inputs) > 1:
                    multi_inputs_contained_cond = True

            if not multi_inputs_contained_cond:
                continue

            new_split_indices = []  # output indices that were routed through a shortcut
            input_to_shortcut = {}  # original output name -> its shortcut layer
            index_to_shortcut_name = {}  # output index -> its shortcut layer name
            orig_output_to_orig_successors_keys = list(orig_output_to_orig_successors.keys())
            for io_index, orig_output in enumerate(orig_output_to_orig_successors):
                # determine correct input/output when there's a mismatch between num of original outputs and
                # the actual i/o copies, or when a split has multiple outputs connected to the same successor
                shortcut_required = False
                for next_key in orig_output_to_orig_successors_keys[io_index + 1 :]:
                    for cur_key_output_node in orig_output_to_orig_successors[orig_output]:
                        if len(orig_output_to_orig_successors[next_key]) > 1:
                            break
                        if cur_key_output_node in orig_output_to_orig_successors[next_key]:
                            shortcut_required = True
                            break
                if len(orig_output_to_orig_successors[orig_output]) > 1 or shortcut_required:
                    # route this output through a new shortcut layer fed by the splitter
                    shortcut = ShortcutLayer()
                    shortcut.index = self._layers_graph.get_next_index()
                    shortcut.name = f"shortcut_{layer.name}_{io_index}"
                    shortcut.inputs = [layer.name]
                    shortcut.input_indices = [layer.index]
                    shortcut.input_shapes = [layer.output_shapes[io_index]]
                    self._layers_graph.add_node(shortcut)
                    self._layers_graph.add_edge(layer, shortcut)
                    input_to_shortcut[orig_output] = shortcut
                    index_to_shortcut_name[io_index] = shortcut.name
                    # connect shortcut to successors, while keeping the correct i/o order
                    outputs = []
                    output_indices = []
                    output_shapes = []
                    other_orig_outputs = [x for x in orig_output_to_orig_successors if x != orig_output]
                    for succ_vertex in orig_output_to_orig_successors[orig_output]:
                        succ_layer = self._layers_graph.get_layer_by_original_name(succ_vertex.name)
                        outputs.append(succ_layer.name)
                        output_indices.append(succ_layer.index)
                        output_shapes.append(layer.output_shapes[io_index])
                        self._layers_graph.add_edge(shortcut, succ_layer)

                        # if the layer we connect via shortcut has another edge to the successor, need to
                        # keep i/o order correct
                        if not any(
                            y.split(":")[-1] == x.split(":")[-1] for x in succ_vertex.input for y in other_orig_outputs
                        ):
                            # the successor uses no other output of this split: move its edge from the splitter
                            # to the shortcut
                            self._layers_graph.remove_edge(layer, succ_layer)
                            # NOTE(review): replace_input_layer is called twice with identical arguments while
                            # replace_input_index is called once; presumably meant for a successor that lists the
                            # splitter twice - confirm against replace_input_layer's semantics.
                            succ_layer.replace_input_layer(layer.name, shortcut.name)
                            succ_layer.replace_input_layer(layer.name, shortcut.name)
                            succ_layer.replace_input_index(layer.index, shortcut.index)
                        else:
                            # rebuild the successor's whole input list in the original vertex input order: this
                            # output maps to the shortcut, others to the predecessor's layer (or its shortcut)
                            new_input_list = [None] * len(succ_vertex.input)
                            succ_vertex_preds = list(self.graph.predecessors(succ_vertex))

                            for idx, inp in enumerate(succ_vertex.input):
                                if succ_vertex.input[idx].split(":")[-1] == orig_output.split(":")[-1]:
                                    new_input_list[idx] = shortcut
                                else:
                                    pred_vertex = next(x for x in succ_vertex_preds if x.name == inp.split(":")[0])
                                    pred_layer = self._layers_graph.get_layer_by_original_name(pred_vertex.name)
                                    # check if the current input wasn't changed to shortcut
                                    current_input_name = succ_vertex.input[idx].split(":")[-1]
                                    new_input_list[idx] = input_to_shortcut.get(current_input_name, pred_layer)
                            succ_layer.inputs = [x.name for x in new_input_list]
                            succ_layer.input_indices = [x.index for x in new_input_list]
                            # concat layers also keep an ordered list of input layer objects
                            if succ_layer.op == LayerType.concat:
                                succ_layer.input_list = new_input_list

                    shortcut.outputs = outputs
                    shortcut.output_indices = output_indices
                    shortcut.output_shapes = output_shapes

                    new_split_indices.append(io_index)

            # update split layer info
            if not layer.split_indices:
                self._update_split_indices(layer)
            layer_split_indices = np.array(layer.split_indices)
            layer_outputs = np.array(layer.outputs, dtype="object")
            layer_output_shapes = np.array(layer.output_shapes)
            # each split index routed through a shortcut keeps its first output slot (now the shortcut) and
            # drops any further duplicate slots
            for i in new_split_indices:
                index = np.where(layer_split_indices == i)[0]
                # the first location is changed to shortcut
                layer_outputs[index[0]] = index_to_shortcut_name[i]
                if len(index) > 1:
                    layer_outputs = np.delete(layer_outputs, index[1:])
                    layer_split_indices = np.delete(layer_split_indices, index[1:])
                    # NOTE(review): outputs and split indices drop index[1:], but only the single shape at
                    # index[0] is dropped; the lengths stay aligned only when there are exactly two copies
                    # (presumably of equal shape) - confirm whether index[1:] was intended here.
                    layer_output_shapes = np.delete(layer_output_shapes, index[0], axis=0)

            layer.split_indices = layer_split_indices.tolist()
            layer.output_shapes = layer_output_shapes.tolist()
            layer.outputs = layer_outputs.tolist()
            # re-resolve output indices from the (possibly shortcut) output names
            layer.output_indices = [
                self._layers_graph.get_layer_by_name(output_name).index for output_name in layer.outputs
            ]

    def _detect_nms_anchors(self):
        """Detect a known NMS end structure and, when one is found, extract NMS config values from the graph.

        Anchorless detection (``_anchorless_yolo_structure_detection``) is tried first. YOLOv5-style detection
        (``_yolov5_anchors_detection``) runs only if the anchorless detection fails. Nothing runs when there are
        no end-node candidates.
        """
        nms_candidates_set = self._get_nms_endings_candidates()
        if not nms_candidates_set:
            # Explicit guard: the former ``a and b or c`` bound as ``(a and b) or c`` and still ran YOLOv5
            # detection (which dereferences the first input layer) on an empty candidate set.
            return

        if self._anchorless_yolo_structure_detection(nms_candidates_set) or self._yolov5_anchors_detection(
            nms_candidates_set,
        ):
            self._extract_nms_config_values_from_graph()

    def _get_nms_endings_candidates(self):
        """Return the layers to inspect for an NMS end structure.

        Conv end candidates come from ``_get_nms_conv_endings_candidates`` and may be swapped for their activation
        successors by ``_update_activation_end_nodes``. When there are none, the real output layers are used
        instead. The chosen list is sorted in place by ``output_shapes[0][-2]`` in descending order and returned.
        """
        candidates = self._get_nms_conv_endings_candidates()
        self._update_activation_end_nodes(candidates)
        if not candidates:
            # nothing found while traversing back from the outputs: fall back to the real output layers
            candidates = self._layers_graph.get_real_output_layers()

        # anchors in the json are ordered by height and width (smaller stride to higher stride)
        candidates.sort(key=lambda candidate: candidate.output_shapes[0][-2], reverse=True)
        return candidates

    def _get_nms_conv_endings_candidates(self):
        """
        Collect Conv layers near the real outputs that may be the end nodes of an NMS structure.

        For each real output layer, predecessors are walked breadth-first for at most ``NMS_COMMON_LAYERS_DEPTH``
        hops, and the Conv layers met are grouped by their distance from that output. Groups are then compared
        from the nearest distance outward. A farther group is dropped as soon as one of its convs has a path to a
        conv of a nearer group that is still kept, so only mutually independent groups remain.

        Returns:
            list: the union of the kept Conv layers over all outputs; the order is not meaningful.
        """
        from collections import deque

        matching_layers = set()
        for output_node in self._layers_graph.get_real_output_layers():
            # FIFO BFS: a node is first popped at its shortest distance from the output
            queue = deque([(output_node, 0)])
            visited = set()
            conv_layers = {}  # distance from output_node -> Conv layers at that distance

            while queue:
                current_node, current_distance = queue.popleft()
                if current_node in visited:
                    continue
                visited.add(current_node)

                if current_node.op == LayerType.base_conv:
                    conv_layers.setdefault(current_distance, []).append(current_node)

                # Stop expanding once the traversal depth limit is reached
                if current_distance >= NMS_COMMON_LAYERS_DEPTH:
                    continue

                queue.extend((pred, current_distance + 1) for pred in self._layers_graph.predecessors(current_node))

            # Keep Conv sets that are independent of each other; BFS inserted the distances in increasing order
            distances = list(conv_layers)
            removed_distances = set()
            for i, distance in enumerate(distances[:-1]):
                if distance in removed_distances:
                    continue
                for farther_distance in distances[i + 1 :]:
                    if farther_distance in removed_distances:
                        continue
                    if any(
                        nx.has_path(self._layers_graph, farther_conv, conv)
                        for conv in conv_layers[distance]
                        for farther_conv in conv_layers[farther_distance]
                    ):
                        removed_distances.add(farther_distance)

            for distance, convs in conv_layers.items():
                if distance not in removed_distances:
                    matching_layers.update(convs)

        return list(matching_layers)

    def _get_end_nodes_activations_dict(self, nms_candidates_set):
        """
        Map each candidate to its activation successor, when that activation is the candidate's only successor
        and the candidate is the activation's only predecessor.
        """
        replacements = {}
        for candidate in nms_candidates_set:
            successors = list(self._layers_graph.successors(candidate))
            if len(successors) != 1:
                continue
            (successor,) = successors
            if successor.op != LayerType.base_activation:
                continue
            if len(list(self._layers_graph.predecessors(successor))) != 1:
                continue
            replacements[candidate] = successor
        return replacements

    def _update_activation_end_nodes(self, nms_candidates_set):
        """
        Swap candidates for their activation successors in place, using ``_get_end_nodes_activations_dict``.

        Nothing is swapped when every candidate has such an activation, because detection needs some regular
        (conv) layers to remain among the candidates.
        """
        replacements = self._get_end_nodes_activations_dict(nms_candidates_set)
        if all(candidate in replacements for candidate in nms_candidates_set):
            return
        nms_candidates_set[:] = [replacements.get(candidate, candidate) for candidate in nms_candidates_set]

    def _yolov5_anchors_detection(self, nms_candidates_set):
        """Detect a YOLOv5 (or YOLOv5-seg) NMS end structure among the candidate layers.

        The candidates must be either 3 or 4 Conv layers, or 3 Conv layers plus one activation layer (reported as
        YOLOv5-seg). Anchors are first read from each conv's vertex via ``get_nms_last_conv_info``. If none are
        found, the default YOLOv5 ``bbox_decoders`` are matched to the convs by stride. This fallback is used only
        when every branch implies the same number of classes.

        On success, ``self._layers_graph.detected_anchors`` is replaced with a dict holding ``info``,
        ``meta_arch`` and ``end_nodes``.

        Args:
            nms_candidates_set: candidate end layers (Conv and/or activation layers).

        Returns:
            bool: True if the structure was detected.
        """
        input_height = self._layers_graph.get_input_layers()[0].input_shape[1]
        anchors = {"info": {}}
        convs = [layer for layer in nms_candidates_set if layer.op == LayerType.base_conv]
        activations = [layer for layer in nms_candidates_set if layer.op == LayerType.base_activation]
        num_candidates = len(nms_candidates_set)
        only_convs = num_candidates in [3, 4] and num_candidates == len(convs)
        convs_and_one_activation = num_candidates == 4 and len(convs) == 3 and len(activations) == 1
        if not (only_convs or convs_and_one_activation):
            return False

        log_msg = ""
        for layer in convs:
            vertex = self._get_vertex_by_layer(layer)
            anchors["info"].update(vertex.get_nms_last_conv_info(input_height))

        if not anchors["info"]:
            # The model might be YOLOv5 or a variant without the postprocess block; if so, default anchors are used.
            classes_of_each_branch = [self._get_nms_classes_from_conv_output_shape(conv.output_shape) for conv in convs]
            # Default anchors only fit when all branches agree on the number of classes
            # (a plain ``set(...)`` check was always true for a non-empty conv list).
            if len(set(classes_of_each_branch)) == 1:
                for bbox_decoder in NMSPostprocessCommand.get_value_from_default_config_json(
                    "bbox_decoders",
                    NMSMetaArchitectures.YOLOV5,
                ):
                    anchor_info = {
                        key: value
                        for key, value in bbox_decoder.items()
                        if key in [BBoxDecodersInfo.W.value, BBoxDecodersInfo.H.value, BBoxDecodersInfo.STRIDE.value]
                    }

                    # pick the conv whose output height times this decoder's stride equals the input height
                    conv_name = [
                        conv.original_names[0]
                        for conv in convs
                        if conv.output_shape[1] * anchor_info[BBoxDecodersInfo.STRIDE.value] == input_height
                    ]
                    if len(conv_name) != 1:
                        # missing or ambiguous match: the model does not fit the default NMS post-process
                        return False

                    anchors["info"][conv_name[0]] = anchor_info
                log_msg = "Default values of NMS anchors were loaded to NMS config json"

        if not anchors["info"]:
            return False

        # NMS meta architecture was detected
        anchors["meta_arch"] = NMSMetaArchitectures.YOLOV5_SEG if len(activations) == 1 else NMSMetaArchitectures.YOLOV5
        log_msg = f"NMS structure of {anchors['meta_arch'].value} (or equivalent architecture) was detected. " + log_msg
        self._logger.info(log_msg)
        self._layers_graph.detected_anchors = anchors
        self._layers_graph.detected_anchors["end_nodes"] = [layer.original_names[-1] for layer in nms_candidates_set]
        return True

    def _get_nms_classes_from_conv_output_shape(self, output_shape):
        """Infer the class count per anchor from a conv output shape's last dimension.

        Computes ``output_shape[-1] / DEFAULT_YOLO_ANCHORS - DEFAULT_BOX_AND_OBJ_PXLS``; the result is a float.
        """
        channels_per_anchor = output_shape[-1] / DEFAULT_YOLO_ANCHORS
        return channels_per_anchor - DEFAULT_BOX_AND_OBJ_PXLS

    def _get_vertex_by_layer(self, layer):
        """Return the graph vertex looked up by the first original name of ``layer``.

        Raises:
            UnsupportedOperationError: if the lookup returns nothing.
        """
        matched_vertex = self.graph.get_vertex_by_valid_name(layer.original_names[0])
        if not matched_vertex:
            raise UnsupportedOperationError(f"Could not find a single vertex matching layer: {layer.name}.")
        return matched_vertex

    def _detect_late_nodes_anchorless_structure(self, real_out_layers):
        """Check for a YOLOX/YOLOv6 structure ending in width/features transpose format conversions.

        Every output must be a ``transpose_width_features`` format conversion. A single such output is checked
        as YOLOX; otherwise the first two outputs' vertices are checked as YOLOv6.

        Returns:
            tuple: ``(meta_arch, found_structure, end_nodes)``, or ``(None, False, [])`` when the outputs are not
            all such format conversions.
        """
        orig_vertices = [self._get_vertex_by_layer(layer) for layer in real_out_layers]
        all_transposed = all(
            layer.op == LayerType.format_conversion
            and layer.conversion_type == FormatConversionType.transpose_width_features
            for layer in real_out_layers
        )
        if not all_transposed:
            return None, False, []

        if len(real_out_layers) == 1:
            found_structure, end_nodes = orig_vertices[0].is_end_of_yolox_structure()
            return NMSMetaArchitectures.YOLOX, found_structure, end_nodes

        found_structure, end_nodes = orig_vertices[0].is_end_of_yolov6_structure(orig_vertices[1])
        return NMSMetaArchitectures.YOLOV6, found_structure, end_nodes

    def _detect_anchorless_convs_meta_arch(self, real_out_layers):
        """Guess which anchorless YOLO head (YOLOX, YOLOv6 or YOLOv8) the output layers belong to.

        Regression layers are the Conv outputs. When the only output is a concat, they are instead the Conv
        inputs of the concat layers found two hops upstream of it. The meta architecture is chosen from the
        regression-layer and sigmoid-activation counts; YOLOv8 also requires every regression layer's vertex to
        expose a possible YOLOv8 post-process.

        Returns:
            tuple: ``(meta_arch, relevant_outputs)``. ``meta_arch`` is None when nothing matches.
            ``relevant_outputs`` is the regression layers for YOLOv8 and ``real_out_layers`` otherwise.
        """
        sigmoid_outputs = [
            layer
            for layer in real_out_layers
            if layer.op == LayerType.base_activation and layer.activation == ActivationType.sigmoid
        ]

        reg_layers = []
        if len(real_out_layers) == 1 and real_out_layers[0].op == LayerType.concat:
            graph = self._layers_graph
            for pred in list(graph.predecessors(real_out_layers[0])):
                for grand_pred in list(graph.predecessors(pred)):
                    if grand_pred.op != LayerType.concat:
                        continue
                    reg_layers.extend(x for x in graph.predecessors(grand_pred) if x.op == LayerType.base_conv)
        else:
            reg_layers = [layer for layer in real_out_layers if layer.op == LayerType.base_conv]

        # NOTE(review): with no regression layers and no sigmoid outputs the first test holds (0 == 0) and YOLOX
        # is chosen; the YOLOv6 test guards against the empty case but this one does not - confirm intent.
        if (
            YOLOX_ACTIVATIONS_PER_REG_LAYER * len(reg_layers) == len(sigmoid_outputs)
            or len(reg_layers) == YOLOX_TOTAL_OUTPUTS
        ):
            meta_arch = NMSMetaArchitectures.YOLOX
        elif (len(sigmoid_outputs) == len(reg_layers) and len(reg_layers) > 0) or (
            len(reg_layers) == YOLOV6_TOTAL_OUTPUTS and not self._detect_yolov8_postprocess(reg_layers)
        ):
            meta_arch = NMSMetaArchitectures.YOLOV6
        elif not sigmoid_outputs and self._detect_yolov8_postprocess(reg_layers):
            meta_arch = NMSMetaArchitectures.YOLOV8
        else:
            meta_arch = None

        return meta_arch, (reg_layers if meta_arch == NMSMetaArchitectures.YOLOV8 else real_out_layers)

    def _anchorless_yolo_structure_detection(self, nms_candidates_set):
        """Detect an anchorless YOLO NMS structure and record it in ``self._layers_graph.detected_anchors``.

        A single concat candidate, or more than two candidates, goes through conv-based detection
        (``_detect_anchorless_convs_meta_arch`` confirmed by ``_are_end_convs_of_anchorless_yolo``). Any other
        candidate set goes through ``_detect_late_nodes_anchorless_structure``. On success, ``meta_arch`` is stored,
        and ``end_nodes`` too when there are any; both are logged.

        Returns:
            The detection result (truthy when a structure was found).
        """
        end_nodes = None
        single_concat = len(nms_candidates_set) == 1 and nms_candidates_set[0].op == LayerType.concat
        if single_concat or len(nms_candidates_set) > 2:
            # identify the basic structure, then confirm the whole structure
            meta_arch, relevant_outputs = self._detect_anchorless_convs_meta_arch(nms_candidates_set)
            if meta_arch and self._are_end_convs_of_anchorless_yolo(relevant_outputs, meta_arch):
                found_structure = True
                end_nodes = [output.original_names[-1] for output in relevant_outputs]
            else:
                found_structure = False
        else:
            meta_arch, found_structure, end_nodes = self._detect_late_nodes_anchorless_structure(nms_candidates_set)

        if not found_structure:
            return found_structure

        self._layers_graph.detected_anchors["meta_arch"] = meta_arch
        self._logger.info(f"NMS structure of {meta_arch.value} (or equivalent architecture) was detected.")
        if end_nodes:
            self._logger.info(
                "In order to use HailoRT post-processing capabilities, these end node names "
                f"should be used: {' '.join(end_nodes)}.",
            )
            self._layers_graph.detected_anchors["end_nodes"] = end_nodes
        return found_structure

    def _detect_yolov8_postprocess(self, reg_layers):
        """Return True if every regression layer's vertex exposes a possible YOLOv8 post-process.

        Checking stops at the first layer without one. An empty ``reg_layers`` yields True.
        """
        return all(
            self._get_vertex_by_layer(reg_layer).get_possible_yolov8_postprocess() is not None
            for reg_layer in reg_layers
        )

    def _are_end_convs_of_anchorless_yolo(self, output_layers, meta_arch):
        """Check that ``output_layers`` form the per-branch outputs of an anchorless YOLO head.

        Outputs are grouped into branches by their output (height, width). Each branch must meet these rules:

        - its output count is allowed by ``YOLO_OUTPUTS_PER_BRANCH[meta_arch]``;
        - it has exactly one regression output with ``get_f_out_by_meta_arch(meta_arch)`` channels;
        - it has exactly one class output, whose channel count is taken from the first branch checked;
        - for YOLOX, it also has exactly one single-channel objectness output.

        (YOLOX branches have reg, obj and cls outputs; YOLOv6 branches have the same without obj.)

        Returns:
            bool: True when every branch matches. Always False if any output has a height of 1, since detection
            is ambiguous there (e.g. between DETR and YOLOv6).
        """
        spatial_sizes = {tuple(layer.output_shapes[0][1:3]) for layer in output_layers}
        if any(size[0] == 1 for size in spatial_sizes):
            return False

        num_classes = None
        allowed_outputs_per_branch = YOLO_OUTPUTS_PER_BRANCH[meta_arch]
        reg_channels = get_f_out_by_meta_arch(meta_arch)
        for size in spatial_sizes:
            branch = [layer for layer in output_layers if tuple(layer.output_shapes[0][1:3]) == size]
            if len(branch) not in allowed_outputs_per_branch:
                return False
            # a 2-output branch (reg + cls) has no objectness output to look for
            branch_may_have_obj = len(branch) != 2

            reg_outputs = [layer for layer in branch if layer.output_shapes[0][-1] == reg_channels]
            obj_outputs = [layer for layer in branch if branch_may_have_obj and layer.output_shapes[0][-1] == 1]
            cls_outputs = [
                layer
                for layer in branch
                if (num_classes and layer.output_shapes[0][-1] == num_classes)
                or (not num_classes and layer not in reg_outputs + obj_outputs)
            ]
            if len(reg_outputs) != 1 or len(cls_outputs) != 1:
                return False
            if meta_arch == NMSMetaArchitectures.YOLOX and len(obj_outputs) != 1:
                return False

            if not num_classes and len(cls_outputs) == 1:
                num_classes = cls_outputs[0].output_shapes[0][-1]
        return True

    def _extract_nms_config_values_from_graph(self):
        """Collect NMS config values from the graph into ``detected_anchors["config_values"]``.

        Values come from the model's NMS vertex when exactly one exists. For YOLOv8, the regression length read
        from the first real output layer's vertex is added when available. Nothing is stored when no values
        were found.
        """
        nms_vertices = [vertex for vertex in self.graph.vertices_by_name.values() if vertex.op in NMS_OPS]
        config_values = nms_vertices[0].get_nms_config_values() if len(nms_vertices) == 1 else {}

        if self._layers_graph.detected_anchors["meta_arch"] == NMSMetaArchitectures.YOLOV8:
            first_output = self._layers_graph.get_real_output_layers()[0]
            reg_length = self._get_vertex_by_layer(first_output).get_yolov8_reg_length()
            if reg_length is not None:
                config_values[NMSProperties.REGRESSION_LENGTH.value] = reg_length

        if config_values:
            self._layers_graph.detected_anchors["config_values"] = config_values

    def _replace_feature_splitter_with_slices(self):
        """
        Replace feature splitters whose outputs are only partially used with slice layers.

        A split vertex in the valid subgraph qualifies when all of the following hold:

        - it was translated into a ``feature_splitter`` layer;
        - at least one of its outputs is unused and at least one is used;
        - every graph successor of the split consumes exactly one split output, exactly once.

        For a qualifying split, one slice layer is created per used output and the output's consumers are rewired
        to it. The feature splitter's entry is removed from its input layer, and the feature splitter layer is
        removed from the layers graph.
        """
        for vertex in self.graph.nodes:
            if vertex.op not in SPLIT_OPS or not vertex.in_valid_subgraph:
                continue

            # the split in the graph might be replaced with other layer type in the layers graph
            feature_splitter_layer = self._layers_graph.get_layer_by_original_name(vertex.name)
            if feature_splitter_layer.op != LayerType.feature_splitter:
                continue

            vertex_successors, index_to_successors = self._map_split_outputs_to_successors(vertex)
            used_counts = [len(output_successors) for output_successors in index_to_successors.values()]
            partially_used = (
                any(count == 0 for count in used_counts)
                and any(count > 0 for count in used_counts)
                # a successor reading several split outputs (or one output twice) makes these counts differ
                and len(vertex_successors) == sum(used_counts)
            )
            if not partially_used:
                continue

            # FS has one input
            pred = next(iter(self._layers_graph.predecessors(feature_splitter_layer)))
            for index, output_successors in index_to_successors.items():
                if output_successors:  # unused outputs get no slice
                    self._create_slice_for_split_output(
                        vertex, feature_splitter_layer, pred, index, output_successors
                    )

            # removes feature splitter information from pred layer
            feature_splitter_index = pred.output_indices.index(feature_splitter_layer.index)
            del pred.output_indices[feature_splitter_index]
            del pred.output_shapes[feature_splitter_index]
            del pred.outputs[feature_splitter_index]

            self._layers_graph.remove_node(feature_splitter_layer)

    def _map_split_outputs_to_successors(self, vertex):
        """
        Map every output index of a split vertex to the graph successors that read that output.

        Returns:
            A ``(vertex_successors, index_to_successors)`` tuple. ``vertex_successors`` lists the split's graph
            successors. ``index_to_successors`` maps each output index to the successors reading it; a successor
            that reads the same output twice appears twice. Successor inputs that do not name one of the split
            outputs are ignored.
        """
        # input names may be prefixed by the producing node and a single ':' (Split:input.8 -> input.8)
        node_prefix_separator = re.compile(r"(?<!:):(?!:)")
        output_to_index = {}
        for i, output_name in enumerate(vertex.output):
            output_to_index.setdefault(output_name, i)

        vertex_successors = list(self.graph.successors(vertex))
        index_to_successors = {i: [] for i in range(len(vertex.output))}
        for successor in vertex_successors:
            for input_name in successor.input:
                used_index = output_to_index.get(input_name)
                if used_index is None:
                    used_index = output_to_index.get(node_prefix_separator.split(input_name, maxsplit=1)[-1])
                if used_index is not None:
                    index_to_successors[used_index].append(successor)
        return vertex_successors, index_to_successors

    def _create_slice_for_split_output(self, vertex, feature_splitter_layer, pred, index, output_successors):
        """
        Create a slice layer for output ``index`` of ``feature_splitter_layer`` and rewire that output's consumers.

        The slice takes ``pred`` (the feature splitter's input layer) as input and selects the features
        ``[sum(split_sizes[:index]), sum(split_sizes[:index + 1]))``. It then replaces the feature splitter as the
        input of every layer translated from ``output_successors``.
        """
        split_sizes = feature_splitter_layer.split_sizes
        start_index = sum(split_sizes[:index])
        end_index = start_index + split_sizes[index]
        slice_layer = SliceLayer.create(
            vertex.name,
            vertex.input,
            height_slice=[0, 0],
            width_slice=[0, 0],
            features_slice=[start_index, end_index],
        )
        slice_layer.name = f"slice{index + 1}_{feature_splitter_layer.name}"
        slice_layer.index = self._layers_graph.get_next_index()
        slice_layer.original_names = feature_splitter_layer.original_names.copy()
        slice_layer.input_indices = feature_splitter_layer.input_indices.copy()
        slice_layer.inputs = feature_splitter_layer.inputs.copy()
        slice_layer.input_shapes = feature_splitter_layer.input_shapes.copy()
        slice_layer.output_shapes = [feature_splitter_layer.output_shapes[index]]
        self._layers_graph.add_node(slice_layer)

        # organizes pred outputs
        pred.append_output_layer(slice_layer.name)
        pred.append_output_index(slice_layer.index)
        pred.append_output_shape(slice_layer.input_shape)
        self._layers_graph.add_edge(pred, slice_layer)

        successors = [
            self._layers_graph.get_layer_by_original_name(output_successor.name)
            for output_successor in output_successors
        ]
        for successor in successors:
            succ_layer = self._get_feature_splitter_consumer(feature_splitter_layer, successor)

            succ_layer.replace_input_layer(feature_splitter_layer.name, slice_layer.name)
            succ_layer.replace_input_index(feature_splitter_layer.index, slice_layer.index)
            succ_layer.replace_input_shape(feature_splitter_layer.name, slice_layer.output_shapes)

            slice_layer.append_output_layer(succ_layer.name)
            slice_layer.append_output_index(succ_layer.index)

            self._layers_graph.add_edge(slice_layer, succ_layer)
            self._layers_graph.remove_edge(feature_splitter_layer, succ_layer)

    def _get_feature_splitter_consumer(self, feature_splitter_layer, succ_layer):
        """
        Return the layer that directly reads ``feature_splitter_layer`` on the path to ``succ_layer``.

        That is ``succ_layer`` itself when it is fed by the feature splitter. Otherwise it is the ``external_pad``
        predecessor of ``succ_layer`` that is fed by the feature splitter. If neither is found, ``succ_layer`` is
        returned unchanged.
        """
        if self._layers_graph.has_edge(feature_splitter_layer, succ_layer):
            return succ_layer
        # the successor layer might be preceded by a padding layer and not by the feature splitter; check every
        # predecessor, since a multi-input successor may have unrelated predecessors first
        for succ_pred in self._layers_graph.predecessors(succ_layer):
            if succ_pred.op == LayerType.external_pad and self._layers_graph.has_edge(
                feature_splitter_layer, succ_pred
            ):
                return succ_pred
        return succ_layer

    def _handle_inner_product_matmul(self):
        """
        Rewire inner-product matmul structures so the transpose reads from the dummy conv's input.

        The structure needs the dummy conv added in the parser. The matmul input order should be the conv as
        activation and the other pred as data input. A matmul is handled when:

        - its first predecessor is a ``transpose_width_features`` format conversion, and
        - its second predecessor is a ``base_conv``.

        For such a matmul, the conv keeps only the matmul as output. The transpose is detached from the conv and
        connected to the conv's own predecessor instead.
        """
        for layer in list(self._layers_graph):
            if layer.op != LayerType.matmul:
                continue

            preds = list(self._layers_graph.predecessors(layer))
            # a matmul fed twice by the same layer has a single unique predecessor in the graph
            if len(preds) < 2:
                continue

            transpose, conv = preds[0], preds[1]
            if not (
                transpose.op == LayerType.format_conversion
                and transpose.conversion_type == FormatConversionType.transpose_width_features
                and conv.op == LayerType.base_conv
            ):
                continue

            conv.outputs = [layer.name]
            conv.output_indices = [layer.index]
            # output_shapes holds one shape per output; keep a list containing only the first shape
            conv.output_shapes = [conv.output_shapes[0]]

            conv_pred = next(iter(self._layers_graph.predecessors(conv)))
            # NOTE(review): this overwrites any other outputs conv_pred had (assumes conv is its only consumer),
            # and conv_pred.output_shapes is not extended for the new transpose output - presumably shapes are
            # recomputed later; TODO confirm
            conv_pred.outputs = [conv.name, transpose.name]
            conv_pred.output_indices = [conv.index, transpose.index]
            transpose.inputs = [conv_pred.name]
            transpose.input_indices = [conv_pred.index]

            self._layers_graph.remove_edge(conv, transpose)
            self._layers_graph.add_edge(conv_pred, transpose)

    def _handle_tokens_matmul(self):
        """
        Insert a width/features transpose format conversion before every input of each tokens matmul layer.

        A matmul layer is handled when the graph vertex named by its first original name reports
        ``is_tokens_matmul()``. Each ``pred -> matmul`` edge is replaced by ``pred -> format_conversion -> matmul``.
        The new layer is a ``transpose_width_features`` conversion that copies the matmul's ``groups`` and
        ``original_names``. The matmul's input references and the predecessor's output reference are updated to
        point at the new layer.
        """
        for layer in list(self._layers_graph):
            if layer.op != LayerType.matmul:
                continue

            vertex = self._graph.vertices_by_name.get(layer.original_names[0])
            if not (vertex and vertex.is_tokens_matmul()):
                continue

            next_idx = self._layers_graph.get_next_index()

            # snapshot the predecessors: the loop removes ``pred -> layer`` edges and adds
            # ``format_conversion -> layer`` edges, so iterating the live predecessor view would mutate it
            # mid-iteration (and could revisit the freshly inserted format conversion layers)
            preds = list(self._layers_graph.predecessors(layer))
            for i, pred in enumerate(preds):
                name = f"format_conversion{i}_{layer.name}"
                format_conversion = self._create_layer(
                    FormatConversionLayer,
                    next_idx,
                    name,
                    [pred.name],
                    [pred.index],
                    [layer.name],
                )
                format_conversion.conversion_type = FormatConversionType.transpose_width_features
                format_conversion.groups = layer.groups
                format_conversion.original_names = layer.original_names.copy()
                next_idx += 1

                # pred -> layer becomes pred -> format_conversion -> layer
                self._layers_graph.remove_edge(pred, layer)
                self._layers_graph.add_edge(pred, format_conversion)
                self._layers_graph.add_edge(format_conversion, layer)
                layer.replace_input_layer(pred.name, format_conversion.name)
                layer.replace_input_index(pred.index, format_conversion.index)
                pred.replace_output_layer(layer.name, format_conversion.name)

            self._logger.debug(f"Added format conversion layers before tokens matmul layer {layer.name}")

    def _prevent_transpose_hw_suggestion(self):
        """
        Move recommended start/end names off height-width transpose nodes.

        A recommended start name whose graph vertex is a height-width transpose is replaced by the names of that
        vertex's successors. A recommended end name of such a vertex is replaced by the names of its predecessors.
        Names not found in ``self.graph.vertices_by_name`` are left untouched. Empty or ``None`` name collections
        are skipped. Only the names present when the method starts are examined; neighbours added here are not
        re-checked.
        """

        def _shift_off_transposes(names, get_neighbors):
            # iterate over a snapshot: the set is modified inside the loop, and mutating a set while iterating it
            # raises "RuntimeError: Set changed size during iteration"
            for name in list(names):
                node = self.graph.vertices_by_name.get(name)
                if node is not None and node.is_height_width_transpose():
                    names.remove(name)
                    names.update([neighbor.name for neighbor in get_neighbors(node)])

        if self._recommended_start_names:
            _shift_off_transposes(self._recommended_start_names, self.graph.successors)
        if self._recommended_end_names:
            _shift_off_transposes(self._recommended_end_names, self.graph.predecessors)
