import re
from copy import deepcopy

import numpy as np
from onnx import numpy_helper

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_CONCAT_AXIS,
    ConcatAxis,
    NMSProperties,
)
from hailo_sdk_client.exposed_definitions import Dims
from hailo_sdk_client.model_translator.edge_nn_translator import INPUT_OP
from hailo_sdk_client.model_translator.exceptions import (
    CantFindGraphStartError,
    UnexpectedNodeError,
    UnsupportedActivationLayerError,
    UnsupportedConcatLayerError,
    UnsupportedConv3DError,
    UnsupportedConvLayerError,
    UnsupportedEinsumLayerError,
    UnsupportedFeatureSplitterLayerError,
    UnsupportedGatherLayerError,
    UnsupportedGRULayerError,
    UnsupportedInputFormatError,
    UnsupportedInputShapesError,
    UnsupportedL2NormLayerError,
    UnsupportedLayerNormLayerError,
    UnsupportedLogitsLayerError,
    UnsupportedLogSoftmaxLayerError,
    UnsupportedLSTMLayerError,
    UnsupportedMultLayerError,
    UnsupportedNormalizationLayerError,
    UnsupportedPaddingError,
    UnsupportedReduceMeanLayerError,
    UnsupportedReshapeError,
    UnsupportedResizeLayerError,
    UnsupportedRNNLayerError,
    UnsupportedScatterNDError,
    UnsupportedSliceLayerError,
    UnsupportedSoftmaxLayerError,
    UnsupportedZeroDimensionShapeError,
)
from hailo_sdk_client.model_translator.graph_lookup import (
    BwdChainNode,
    FwdChainNode,
    get_all_nodes_from_possible_chains,
    get_all_nodes_in_chain,
    get_node_from_possible_chains,
    look_and_validate,
    look_for_node,
)
from hailo_sdk_client.model_translator.nn_graph import NNGraph, NNGraphNode
from hailo_sdk_client.model_translator.onnx_translator.onnx_translator_definitions import Conv3DInfo
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import (
    ActivationType,
    LayerType,
    PaddingType,
    ResizeBilinearPixelsMode,
    ResizeMethod,
    TemporaryPaddingType,
)
from hailo_sdk_common.hailo_nn.hn_layers import BatchNormValues
from hailo_sdk_common.logger.logger import default_logger

# Operator-name groups.
# ONNXGraphNode.op holds the node's ONNX op_type (or INPUT_OP for graph inputs) and is
# matched against these lists with `in`, so every entry must be spelled exactly as the
# op_type string.
INPUT_OPS = [INPUT_OP]
OUTPUT_OPS = ["output"]
CONV2D_OPS = ["Conv", "ConvTranspose"]
DENSE_OPS = ["Gemm", "MatMul"]
POOL_OPS = ["AveragePool", "MaxPool", "GlobalAveragePool", "GlobalMaxPool", "Mean", "ReduceMean"]
BN_OPS = ["BatchNormalization"]
# Note: "Min" and "Max" appear here as well as in MIN_OPS / MAX_OPS below.
ACTIVATION_OPS = [
    "Relu",
    "Elu",
    "LeakyRelu",
    "Exp",
    "Sigmoid",
    "Clip",
    "Tanh",
    "Abs",
    "Sign",
    "Greater",
    "PRelu",
    "Softplus",
    "Erf",
    "Sqrt",
    "Less",
    "Log",
    "HardSigmoid",
    "Min",
    "Max",
    "HardSwish",
    "Softsign",
    "Reciprocal",
]
LOGITS_OPS = ["Softmax", "ArgMax"]
# Presumably the ops that may appear in a hand-built (decomposed) softmax - confirm against the users.
ALTERNATIVE_SOFTMAX_OPS = ["ReduceSum", "Exp", "Transpose"]
REDUCE_MAX_OPS = ["ReduceMax"]
REDUCE_MIN_OPS = ["ReduceMin"]
REDUCE_SUM_OPS = ["ReduceSum"]
REDUCE_L2_OPS = ["ReduceL2"]
REDUCE_SUM_SQUARE_OPS = ["ReduceSumSquare"]
ADD_OPS = ["Add", "Sum"]
CONCAT_OPS = ["Concat"]
PAD_OPS = ["Pad"]
SPLIT_OPS = ["Split"]
SLICE_OPS = ["Slice"]
RESIZE_OPS = ["Resize", "Upsample"]
TILE_OPS = ["Tile"]
SHUFFLE_OPS = ["Reshape", "Transpose", "DepthToSpace", "SpaceToDepth"]
DROP_OPS = ["Dropout"]
MUL_OPS = ["Mul"]
DIV_OPS = ["Div"]
SUB_OPS = ["Sub"]
NEG_OPS = ["Neg"]
MAX_OPS = ["Max"]
MIN_OPS = ["Min"]
POW_OPS = ["Pow"]
MATH_OPS = ["Cast", "Floor"]
SHAPE_OPS = ["Constant", "Shape", "ConstantOfShape", "Expand", "Squeeze", "Unsqueeze"]
EQUAL_OPS = ["Equal"]
LOGICAL_OPS = ["Where", *EQUAL_OPS]
INSTANCE_NORMALIZATION_OPS = ["InstanceNormalization"]
LAYER_NORMALIZATION_OPS = ["LayerNormalization"]
# Element-wise binary ops: the concatenation of the add/sub/mul/div/max/min groups above.
EW_OPS = ADD_OPS + SUB_OPS + MUL_OPS + DIV_OPS + MAX_OPS + MIN_OPS
# Predecessors with these ops are ignored by ONNXGraphNode.get_start_node_preds.
CONST_OPS = ["Constant", "Identity"]
EINSUM_OPS = ["Einsum"]
ONE_HOT_OPS = ["OneHot"]
RNN_OPS = ["RNN"]
LSTM_OPS = ["LSTM"]
GRU_OPS = ["GRU"]
NMS_OPS = ["NonMaxSuppression"]
GATHER_OPS = ["Gather"]
LOG_SOFTMAX_OPS = ["LogSoftmax"]
SCATTER_ND_OPS = ["ScatterND"]

# Composite groups built from the basic operator lists above.
SKIP_OPS = [*CONST_OPS, "Shape", "ConstantOfShape"]
OPTIONAL_NULL_OPS = (
    PAD_OPS
    + DROP_OPS
    + ADD_OPS
    + MATH_OPS
    + SPLIT_OPS
    + ["Clip", "Reshape", "Flatten", "Unsqueeze", "Expand", "Squeeze", "Gather", "Transpose", "Slice", "ReduceMean"]
)
PRE_LAYER_OPS = ["Flatten", "Reshape", "Transpose"]
# Used by ONNXGraphNode.update_output_format to detect reductions that drop dimensions.
REDUCE_OPS = REDUCE_MAX_OPS + REDUCE_SUM_OPS + REDUCE_L2_OPS + REDUCE_SUM_SQUARE_OPS + ["ReduceMean"] + REDUCE_MIN_OPS
RNN_SEQ_OPS = [*RNN_OPS, *LSTM_OPS, *GRU_OPS]

# Plain list concatenation: an op name that belongs to several groups (e.g. "Add", which is
# in both ADD_OPS and OPTIONAL_NULL_OPS) appears more than once, so use this for membership
# tests only and do not rely on its length.
SUPPORTED_OPS_UNION = (
    SKIP_OPS
    + CONV2D_OPS
    + DENSE_OPS
    + POOL_OPS
    + BN_OPS
    + ACTIVATION_OPS
    + ADD_OPS
    + SPLIT_OPS
    + CONCAT_OPS
    + LOGITS_OPS
    + RESIZE_OPS
    + SHUFFLE_OPS
    + SLICE_OPS
    + DROP_OPS
    + MUL_OPS
    + DIV_OPS
    + ALTERNATIVE_SOFTMAX_OPS
    + SUB_OPS
    + REDUCE_MAX_OPS
    + REDUCE_MIN_OPS
    + PAD_OPS
    + REDUCE_SUM_OPS
    + POW_OPS
    + LOGICAL_OPS
    + REDUCE_L2_OPS
    + INSTANCE_NORMALIZATION_OPS
    + OPTIONAL_NULL_OPS
    + EINSUM_OPS
    + REDUCE_SUM_SQUARE_OPS
    + TILE_OPS
    + LAYER_NORMALIZATION_OPS
    + RNN_OPS
    + LSTM_OPS
    + ONE_HOT_OPS
    + GATHER_OPS
    + LOG_SOFTMAX_OPS
    + MAX_OPS
    + NEG_OPS
    + GRU_OPS
    + SCATTER_ND_OPS
)

# Name fragments for classifying initializers as kernel or bias.
# Presumably matched as substrings of the initializer name - confirm against the users.
KERNEL_INITIALIZERS_NAMES = ["kernel", "weight", "_w_", "alpha"]
BIAS_INITIALIZERS_NAMES = ["bias", "_b_", "beta"]

# Name fragments for the four batch-normalization parameters (same assumption as above).
BN_GAMMA_NAMES = ["weight", "gamma", "scale", "_w_"]
BN_BETA_NAMES = ["bias", "beta", "_b_"]
BN_MOVING_VAR_NAMES = ["moving_var", "running_var", "var"]
BN_MOVING_MEAN_NAMES = ["moving_mean", "running_mean", "mean"]

OPS_WITH_WEIGHTS = CONV2D_OPS + DENSE_OPS
OPS_WITH_WEIGHTS_PARAMS_ORDER = ["X", "kernel", "bias"]

# Per-operator input tables. Presumably the i-th entry names the i-th entry of the node's
# `input` list - confirm against the code that indexes them.
BN_INPUT_ORDER = ["X", "gamma", "beta", "mean", "var"]
RESIZE_INPUT_ORDER = ["X", "roi", "scales", "sizes"]
UPSAMPLE_INPUT_ORDER = ["X", "scales"]
CLIP_INPUT_ORDER = ["X", "min", "max"]
CONV2D_INPUT_ORDER = ["X", "W", "B"]
RESHAPE_INPUT_ORDER = ["X", "shape"]
INSTANCE_NORMALIZATION_INPUT_ORDER = ["input", "scale", "B"]
LAYER_NORMALIZATION_INPUT_ORDER = ["input", "scale", "B"]
EINSUM_INPUT_ORDER = ["Inputs", "var"]
SQUEEZE_INPUT_ORDER = ["X", "axes"]
UNSQUEEZE_INPUT_ORDER = ["X", "axes"]
REDUCE_INPUT_ORDER = ["X", "axes"]
SLICE_INPUT_ORDER = ["data", "starts", "ends", "axes", "steps"]
SPLIT_INPUT_ORDER = ["X", "split"]
PRELU_INPUT_ORDER = ["X", "slope"]
MIN_INPUT_ORDER = ["X", "initializer"]
MAX_INPUT_ORDER = ["X", "initializer"]
RNN_INPUT_ORDER = ["X", "W", "R", "B", "sequence_lens", "initial_h"]
LSTM_INPUT_ORDER = ["X", "W", "R", "B", "sequence_lens", "initial_h", "initial_c"]
GRU_INPUT_ORDER = ["X", "W", "R", "B", "sequence_lens", "initial_h"]
PAD_INPUT_ORDER = ["data", "pads", "constant_value"]
TILE_INPUT_ORDER = ["X", "repeats"]
ADD_ORDER = ["X", "initializer"]
# NonMaxSuppression parameter names, and the NMSProperties key each one maps to.
NMS_ORDER = ["max_output_boxes_per_class", "iou_threshold", "score_threshold"]
NMS_ONNX_KEY_TO_CONFIG_KEY = {
    "max_output_boxes_per_class": NMSProperties.MAX_PROPOSALS_PER_CLASS,
    "iou_threshold": NMSProperties.IOU_TH,
    "score_threshold": NMSProperties.SCORES_TH,
}

SLICE_ATTRS_ORDER = ["starts", "ends"]
DEFAULT_SPACE_TO_DEPTH_BLOCK_SIZE = 2

# Einsum equation strings that are recognised; the trailing comment on each entry names the
# layer pattern it stands for.
EINSUM_SUPPORTED_EQUATIONS = [
    "bmchw,bnmc->bmhwn",  # group conv
    "bchw,cj->bjhw",  # conv 1x1
    "byhwc,hkc->byhwk",  # matmul
    "byhwc,wkc->byhwk",  # matmul
    "nkctv,kvw->nctw",  # einsum gcn
]

# Maps an auto_pad string to the padding enum used by the translator.
AUTO_PAD_TO_PADDING_TYPE = {
    "NOTSET": TemporaryPaddingType.external_undecided,
    "SAME_UPPER": PaddingType.same_tensorflow,
    "SAME_LOWER": TemporaryPaddingType.same_lower,
    "VALID": PaddingType.valid,
}

# Fallback dims layout keyed by tensor rank (used by ONNXGraphNode.update_gather_output_format
# when a Gather on an input increases the rank).
DEFAULT_FORMAT_BY_RANK = {
    2: [Dims.BATCH, Dims.CHANNELS],
    3: [Dims.BATCH, Dims.WIDTH, Dims.CHANNELS],
    4: [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH],
    5: [Dims.BATCH, Dims.CHANNELS, Dims.GROUPS, Dims.HEIGHT, Dims.WIDTH],
}

# Not referenced in this part of the file; presumably separates a vertex name from an
# output index - confirm against the users.
VERTEX_NAME_SEPARATOR = ":"


class ONNXGraphNode(NNGraphNode):
    def __init__(self, node_proto, graph, is_input_vertex=False):
        """
        Wrap an ONNX proto object as a graph vertex.

        Args:
            node_proto: proto describing the vertex; its ``name`` is always read, and for
                regular nodes also ``op_type``, ``input`` and ``output``.
            graph: the graph that owns the vertex (forwarded to the base class).
            is_input_vertex: when truthy the vertex is a network input: its op becomes
                ``INPUT_OP`` and its single output is named after the vertex itself.
        """
        super().__init__(node_proto, graph)
        self.name = node_proto.name
        if not is_input_vertex:
            # regular node: mirror the proto fields
            self.op = node_proto.op_type
            self.input = node_proto.input
            self.output = node_proto.output
        else:
            self.op = INPUT_OP
            self.output = [self.name]

        self._is_spatial_1x1 = False

    @property
    def is_spatial_1x1(self):
        return self._is_spatial_1x1

    @is_spatial_1x1.setter
    def is_spatial_1x1(self, value):
        self._is_spatial_1x1 = value

    def get_pred_and_input_format(self):
        """
        Returns:
            pred: pred with output format that is not None or pred without output format if not exist otherwise or None
                if preds is empty or input format is from net_input_format dict.
            input_format: if exists, input format from net_input_format dict, else output format of pred, None otherwise
        """
        preds = list(self.graph.predecessors(self))

        if self.name in self._graph.net_input_format:
            return None, self._graph.net_input_format[self.name]

        if not preds:
            return None, None

        for pred_iter in preds:
            pred = pred_iter
            if pred.output_format:
                return pred, pred.output_format.copy()

        return preds[0], None

    def update_unsqueeze_output_format(self, pred, input_format):
        if self.is_unsqueeze_before_conv3d():
            self.output_format = [Dims.BATCH, Dims.CHANNELS, Dims.GROUPS, Dims.HEIGHT, Dims.WIDTH]
            return

        axes = self.get_axes_information()
        new_dims = []
        if pred and pred.is_spatial_flatten_reshape() and axes == [0]:
            # edge case - yolov5_c3tr_transformer_unsqueeze
            new_dims = [Dims.HEIGHT]
        elif self.is_successive_unsqueeze_flat_to_frame():
            input_format = [Dims.BATCH, Dims.CHANNELS]
            output_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
            new_dims = [Dims.WIDTH] if len(output_shape) == 3 else [Dims.HEIGHT, Dims.WIDTH]
        elif input_format:
            axes = [x if x != -1 else len(input_format) for x in axes]
            if self.is_unsqueeze_to_stack():
                new_dims = [Dims.STACK]
            elif len(axes) == 1:
                for dim in [Dims.BATCH, Dims.CHANNELS, Dims.WIDTH, Dims.GROUPS, Dims.HEIGHT]:
                    if dim not in input_format:
                        if dim != Dims.GROUPS or axes not in [[1], [3]] or Dims.HEIGHT in input_format:
                            # when axes = [1] or [3] and height isn't in input format we prefer height
                            new_dims = [dim]
                            break

        if input_format:
            output_format = input_format.copy()
            output_format[axes[0] : axes[0]] = new_dims
            self.output_format = output_format

    def update_reshape_output_format(self, input_format):
        """
        Set the output format of a Reshape vertex according to the pattern it belongs to.

        The reshape is classified by a chain of ``is_*`` predicates that are evaluated in
        order; the first one that matches decides the format. Some patterns assign the
        format to a vertex further down the chain (second reshape / transpose) instead of
        to this vertex. If nothing matches, ``self.output_format`` is left untouched.

        Args:
            input_format: dims list describing the reshape input, or ``None`` when unknown.
                Several branches modify this list in place (``remove`` / ``pop`` /
                slice-assignment), so callers should pass a list they own.
        """
        # These predicates return a (matched, format) pair, so they are evaluated up front and
        # the flags are consulted further down the chain.
        is_features_to_heads, f_to_g_format = self.is_features_to_groups_reshape(input_format)
        is_groups_to_features, g_to_f_format = self.is_groups_to_features_reshape(input_format)
        is_group_norm, g_to_spatial_flatten_format = self.is_group_norm_reshape()
        is_windows_to_input, windows_format = self.is_windows_to_input_chain_end()

        if self.is_spatial_flatten_reshape():
            input_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
            output_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
            if self._is_spatial_flatten_with_features_to_heads_reshape():
                self.output_format = [Dims.BATCH, Dims.GROUPS, Dims.CHANNELS, Dims.WIDTH]
            elif len(output_shape) == 2:
                self.output_format = (
                    [Dims.WIDTH, Dims.CHANNELS]
                    if input_format
                    == [
                        Dims.WIDTH,
                        Dims.BATCH,
                        Dims.GROUPS,
                        Dims.CHANNELS,
                    ]
                    else [Dims.CHANNELS, Dims.WIDTH]
                )
            elif (
                input_format
                and len(output_shape) == len(input_format) == 4
                and input_shape[input_format.index(Dims.CHANNELS)] == output_shape[input_format.index(Dims.CHANNELS)]
            ):
                # no need to change output format e.g [1, h, w, c] -> [1, 1, h*w, c]
                self.output_format = input_format
            elif input_format and Dims.HEIGHT in input_format:
                # drops HEIGHT from the caller's list in place, then reuses that list
                input_format.remove(Dims.HEIGHT)
                self.output_format = input_format
                if len(self.output_format) < len(output_shape) and Dims.BATCH not in input_format:
                    self.output_format.insert(0, Dims.BATCH)

        elif is_features_to_heads:
            if self._is_features_to_height_reshape(input_format):
                self.output_format = [Dims.BATCH, Dims.HEIGHT, Dims.CHANNELS, Dims.WIDTH]
            elif self._is_flatten_height_to_features_reshape(input_format):
                self.output_format = [Dims.BATCH, Dims.CHANNELS, Dims.WIDTH]
            elif self.is_features_to_stack_with_flat_height_reshape(input_format):
                self.output_format = [Dims.BATCH, Dims.WIDTH, Dims.STACK, Dims.GROUPS, Dims.CHANNELS]
            elif self.is_features_to_stack_with_flat_groups_reshape(input_format):
                self.output_format = [Dims.WIDTH, Dims.GROUPS, Dims.STACK, Dims.CHANNELS]
            else:
                # fall back to the format reported by is_features_to_groups_reshape
                self.output_format = f_to_g_format

        elif is_groups_to_features:
            self.output_format = g_to_f_format

        elif is_group_norm:
            # The format is assigned to the second reshape of the
            # Reshape -> InstanceNormalization -> Reshape chain, not to this vertex.
            # NOTE(review): look_for_node's result is used without a None check here, unlike in
            # is_null_not_keepdims - confirm is_group_norm_reshape guarantees the chain exists.
            chain = [FwdChainNode("InstanceNormalization"), FwdChainNode("Reshape")]
            second_reshape = look_for_node(self._graph, self, chain)
            second_reshape.output_format = (
                input_format if g_to_spatial_flatten_format is None else g_to_spatial_flatten_format
            )

        elif self.is_shuffle():
            chain = [FwdChainNode("Transpose"), FwdChainNode("Reshape")]
            nodes = get_all_nodes_in_chain(self.graph, self, chain)
            if nodes:
                transpose, last_reshape = nodes
                # the closing reshape restores the original format (copied before it is edited below)
                last_reshape.output_format = input_format.copy()
                perm = transpose.get_transpose_perm()
                channels_index = input_format.index(Dims.CHANNELS)
                extra_dims = [Dims.GROUPS, Dims.STACK]
                # rank-6 perm -> two extra dims, rank-5 perm -> one, otherwise none
                extra_dims_indices = 2 if len(perm) == 6 else (1 if len(perm) == 5 else 0)
                # insert the extra dims (in reversed order) right before CHANNELS, in place
                input_format[channels_index:channels_index] = extra_dims[:extra_dims_indices][::-1]
                self.output_format = input_format
            elif input_format == [Dims.WIDTH, Dims.CHANNELS]:
                # edge case of qwen2_vl_vision
                self.output_format = input_format.copy()

        elif self.is_squeeze_on_batch_reshape():
            self.output_format = input_format
            self.output_format.pop(self.output_format.index(Dims.BATCH))

        elif self.is_unsqueeze_on_batch_reshape():
            self.output_format = input_format.copy()
            axis = self._find_unsqueeze_axis()
            self.output_format.insert(axis, Dims.BATCH)

        elif (input_format == [Dims.BATCH, Dims.CHANNELS] and self.is_flat_to_frames_reshape()) or (
            input_format == [Dims.BATCH, Dims.HEIGHT, Dims.WIDTH] and self.is_nhw_to_nchw_reshape(input_format)
        ):
            self.output_format = [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]

        elif self.is_spatial_unflatten():
            if self._is_spatial_unflatten_with_g_to_f_reshape() or input_format == [Dims.CHANNELS, Dims.WIDTH]:
                # output format considers the size of output shape due to special case of rank3 output shape.
                self.output_format = (
                    [Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]
                    if len(self.get_output_shapes(convert_to_nhwc=False)[0]) == 3
                    else [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]
                )
            elif self.is_reshape_transpose_expand_height_dim():
                # the transpose successor changes the output format to [batch, height, width, channels]
                self.output_format = [Dims.BATCH, Dims.WIDTH, Dims.HEIGHT, Dims.CHANNELS]
            elif input_format and Dims.WIDTH in input_format:
                self.output_format = input_format.copy()
                # make sure HEIGHT sits immediately before WIDTH
                if Dims.HEIGHT in input_format:
                    self.output_format.remove(Dims.HEIGHT)
                self.output_format.insert(self.output_format.index(Dims.WIDTH), Dims.HEIGHT)
                if self.is_reshape_for_gather_features_slice():
                    self.output_format.pop(0)
                if Dims.GROUPS in self.output_format and len(self.get_output_shapes(False)[0]) == 6:
                    self.output_format.insert(5, Dims.STACK)

        elif self.is_flatten_reshape()[0]:
            self.output_format = [Dims.BATCH, Dims.CHANNELS]

        elif self.is_spatial_flatten_features_to_width():
            # the format is assigned to the following transpose, not to this vertex
            transpose = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])
            transpose.output_format = [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]

        elif self.is_torch_resize_nearest_reshape():
            self.output_format = input_format
            self.output_format.pop(input_format.index(Dims.HEIGHT))

        elif self.is_input_to_windows_chain_end():
            # split input to windowed attention
            self.output_format = [Dims.HEIGHT, Dims.WIDTH, Dims.CHANNELS]

        elif is_windows_to_input:
            # merge windowed attention to input
            self.output_format = windows_format

        elif self.is_groups_to_spatial_flatten():
            self.output_format = (
                [Dims.BATCH, Dims.CHANNELS, Dims.WIDTH]
                if len(self.get_output_shapes(convert_to_nhwc=False)[0]) == 3
                else [Dims.BATCH, Dims.HEIGHT, Dims.CHANNELS, Dims.WIDTH]
            )

        elif self.is_spatial_flatten_to_groups():
            self.output_format = [Dims.BATCH, Dims.GROUPS, Dims.HEIGHT, Dims.WIDTH, Dims.CHANNELS]

        elif self.is_spatial_flatten_and_groups_to_features():
            self.output_format = [Dims.BATCH, Dims.CHANNELS, Dims.GROUPS, Dims.HEIGHT, Dims.WIDTH]

        elif self.is_partial_groups_to_spatial_flatten():
            # includes the format of the followed transpose
            self.output_format = [Dims.BATCH, Dims.GROUPS, Dims.WIDTH, Dims.CHANNELS]

        # NOTE(review): this branch looks unreachable - is_spatial_unflatten() is already
        # consumed by an earlier elif in this chain. Confirm and remove, or move it up.
        elif self.is_spatial_unflatten() and input_format == [Dims.BATCH, Dims.WIDTH, Dims.CHANNELS]:
            self.output_format = [Dims.BATCH, Dims.HEIGHT, Dims.WIDTH, Dims.CHANNELS]

        elif self.is_flatten_height_stack_reshape() and input_format == [
            Dims.BATCH,
            Dims.GROUPS,
            Dims.HEIGHT,
            Dims.WIDTH,
            Dims.CHANNELS,
            Dims.STACK,
        ]:
            self.output_format = [Dims.BATCH, Dims.GROUPS, Dims.WIDTH, Dims.CHANNELS]

        elif self.is_features_to_stack():
            self.output_format = [Dims.BATCH, Dims.GROUPS, Dims.HEIGHT, Dims.WIDTH, Dims.CHANNELS, Dims.STACK]

        elif self.is_flatten_width_over_features_reshape():
            self.output_format = [Dims.BATCH, Dims.HEIGHT, Dims.WIDTH, Dims.CHANNELS]

    def update_gather_output_format(self, pred, input_format):
        if self.is_reducing_rank_gather(pred):
            self.output_format = input_format
            self.output_format.pop(self.get_axis())
        elif self.is_null_unsqueeze_gather(pred):
            # take the input format of the unsqueeze/transpose node
            self.output_format = pred.get_pred_and_input_format()[-1]
        elif self.is_unsqueeze_concat_gather():
            unsqueeze = look_for_node(self._graph, self, [BwdChainNode("Concat"), BwdChainNode("Unsqueeze")])
            self.output_format = unsqueeze.get_pred_and_input_format()[-1]
        elif input_format and len(input_format) == len(self.get_output_shapes(convert_to_nhwc=False)[0]):
            self.output_format = input_format
        elif self.is_input_gather_increasing_rank(pred):
            gather_out_rank = len(self.get_output_shapes(convert_to_nhwc=False)[0])
            self.output_format = DEFAULT_FORMAT_BY_RANK.get(gather_out_rank)

    def update_output_format(self):
        """
        Update the vertex output format according to its pred(s) in toposort order.

        Does nothing when the vertex already has a (truthy) output format. Otherwise the
        format is derived from the op type and the input format returned by
        ``get_pred_and_input_format``; ops without a dedicated rule inherit the input
        format as is.

        After the op-specific rule ran, a format whose length differs from the rank of the
        first output shape (or a vertex with no known output shapes) is reset to ``None``.
        """
        if self.output_format:
            return

        preds = list(self.graph.predecessors(self))
        pred, input_format = self.get_pred_and_input_format()

        if self.op == "Unsqueeze":
            self.update_unsqueeze_output_format(pred, input_format)

        elif self.op == "Squeeze":
            if self.is_spatial_flatten_reshape():
                self.output_format = [Dims.BATCH, Dims.CHANNELS, Dims.WIDTH]
            elif self.is_squeeze_after_conv3d():
                self.output_format = [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]
            elif input_format:
                axes = self.get_squeeze_axes()
                if axes:
                    # drop the squeezed positions from the input format
                    self.output_format = [x for i, x in enumerate(input_format) if i not in axes]
                    # GROUPS without CHANNELS: GROUPS takes over the CHANNELS role
                    if Dims.GROUPS in self.output_format and Dims.CHANNELS not in self.output_format:
                        out_format = [dim if dim != Dims.GROUPS else Dims.CHANNELS for dim in self.output_format]
                        self.output_format = out_format

        elif self.op == "MatMul" and self.is_matmul_layer():
            # the output format is the last item reported by get_matmul_layer_info()
            self.output_format = self.get_matmul_layer_info()[-1]

        elif self.op == "Reshape":
            self.update_reshape_output_format(input_format)

        elif self.op == "Transpose":
            if not (
                self.is_transpose_after_spatial_flatten()
                or self.is_transposed_batch_norm()
                or self.is_transposed_batch_norm_second_transpose()
                or self.is_null_transpose()
                or self.is_spatial_unflatten()
            ) and (self.is_width_features_transpose() or self.is_height_width_transpose() or self.is_hc_transpose()):
                # No need to change the format if the layer changes the HN
                self.output_format = input_format
            else:
                # permute the input format with the transpose perm (only when ranks agree)
                perm = self.get_transpose_perm()
                if input_format and len(input_format) == len(perm):
                    self.output_format = [input_format[i] for i in perm]

        elif self.op == "Flatten":
            self.output_format = [Dims.BATCH, Dims.CHANNELS]

        elif self.op in REDUCE_OPS and input_format and not self.get_keepdims() and not self.is_null_not_keepdims():
            if self.is_reduce_max_after_group_conv_einsum():
                self.output_format = [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]
            else:
                # normalise negative axes, then drop the reduced positions
                indices = [i if i >= 0 else i + len(input_format) for i in self.get_axes_information()]
                self.output_format = [input_format[i] for i in range(len(input_format)) if i not in indices]
                if Dims.GROUPS in self.output_format and Dims.CHANNELS not in self.output_format:
                    self.output_format = [dim if dim != Dims.GROUPS else Dims.CHANNELS for dim in self.output_format]

        elif self.op in GATHER_OPS and pred:
            self.update_gather_output_format(pred, input_format)

        elif self.op in EW_OPS:
            self.output_format = input_format
            if self.is_ew_op_with_const_input():
                const_shape = self.get_const_input_values().shape
                # the constant has exactly one more dim than the input: prepend BATCH
                if input_format and len(const_shape) - len(input_format) == 1 and Dims.BATCH not in input_format:
                    self.output_format = [Dims.BATCH, *input_format]
            elif len(preds) == 2:
                if (
                    preds[0].output_format
                    and preds[1].output_format
                    and preds[0].output_format != preds[1].output_format
                ):
                    # the inputs disagree: prefer the format that contains GROUPS
                    if Dims.GROUPS in preds[0].output_format and Dims.GROUPS not in preds[1].output_format:
                        self.output_format = preds[0].output_format
                    elif Dims.GROUPS in preds[1].output_format and Dims.GROUPS not in preds[0].output_format:
                        self.output_format = preds[1].output_format

        elif self.op in LSTM_OPS + GRU_OPS + RNN_OPS:
            # from ONNX docs: [seq_length, num_directions, batch_size, hidden_size]
            self.output_format = [Dims.WIDTH, Dims.GROUPS, Dims.BATCH, Dims.CHANNELS]

        elif self.op in EINSUM_OPS and self.is_group_conv_einsum():
            self.output_format = [Dims.BATCH, Dims.GROUPS, Dims.HEIGHT, Dims.WIDTH, Dims.CHANNELS]

        elif (
            self.op in CONCAT_OPS
            and len(self.input_formats) == 2
            and [Dims.BATCH, Dims.HEIGHT, Dims.CHANNELS] in self.input_formats
            and [Dims.BATCH, Dims.WIDTH, Dims.CHANNELS] in self.input_formats
            and self.get_axis() == 1
        ):
            self.output_format = [Dims.BATCH, Dims.WIDTH, Dims.CHANNELS]

        else:
            self.output_format = input_format

        # Sanity check: the format must have one entry per dim of the first output shape.
        if self.output_format is not None:
            output_shapes = self.get_output_shapes(convert_to_nhwc=False)
            if not output_shapes or len(output_shapes[0]) != len(self.output_format):
                self.output_format = None

        # Only reached when no format was derived above (output_format is None).
        elif self.is_null_reshape()[0]:
            self.output_format = input_format

    def _is_features_to_height_reshape(self, input_format):
        """
        Detect the yolov8s reshape that is followed by Transpose -> Softmax -> [Transpose ->] Conv.

        Only a ``[BATCH, CHANNELS, WIDTH]`` input is considered. With the short chain the
        transpose perm must be ``[0, 2, 1, 3]``; with the long chain the first perm must be
        ``[0, 3, 1, 2]`` and the second ``[0, 3, 2, 1]``.
        """
        # specific chain from yolov8s
        if input_format != [Dims.BATCH, Dims.CHANNELS, Dims.WIDTH]:
            return False

        short_chain = [FwdChainNode(op="Transpose"), FwdChainNode(op="Softmax"), FwdChainNode(op="Conv")]
        long_chain = [
            FwdChainNode(op="Transpose"),
            FwdChainNode(op="Softmax"),
            FwdChainNode(op="Transpose"),
            FwdChainNode(op="Conv"),
        ]
        matched = get_all_nodes_from_possible_chains(self._graph, self, [short_chain, long_chain])
        if not matched:
            return False

        if len(matched) == 3:
            return matched[0].get_transpose_perm() == [0, 2, 1, 3]

        # long chain: both transposes have to carry the expected perms
        leading_ok = matched[0].get_transpose_perm() == [0, 3, 1, 2]
        trailing_ok = matched[2].get_transpose_perm() == [0, 3, 2, 1]
        return leading_ok and trailing_ok

    def _is_flatten_height_to_features_reshape(self, input_format):
        """
        Detect a reshape that folds the height dim into the features.

        Two situations are recognised:

        * ``[BATCH, CHANNELS, HEIGHT, WIDTH]`` input preceded by the yolov8s chain
          Transpose -> Softmax -> [Transpose ->] Conv with the expected perms;
        * ``[BATCH, HEIGHT, WIDTH, CHANNELS]`` input whose height * channels equals the
          last dim of the first output shape.
        """
        if input_format == [Dims.BATCH, Dims.HEIGHT, Dims.WIDTH, Dims.CHANNELS]:
            in_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
            out_shapes = self.get_output_shapes(convert_to_nhwc=False)
            height = in_shape[input_format.index(Dims.HEIGHT)]
            channels = in_shape[input_format.index(Dims.CHANNELS)]
            return height * channels == out_shapes[0][-1]

        if input_format != [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]:
            return False

        # specific chain from yolov8s
        short_chain = [BwdChainNode(op="Conv"), BwdChainNode(op="Softmax"), BwdChainNode(op="Transpose")]
        long_chain = [
            BwdChainNode(op="Conv"),
            BwdChainNode(op="Transpose"),
            BwdChainNode(op="Softmax"),
            BwdChainNode(op="Transpose"),
        ]
        matched = get_all_nodes_from_possible_chains(self._graph, self, [short_chain, long_chain])
        if not matched:
            return False

        if len(matched) == 3:
            return matched[-1].get_transpose_perm() == [0, 2, 1, 3]

        # long chain: check the farthest transpose and the one in the middle
        farthest_ok = matched[-1].get_transpose_perm() == [0, 3, 1, 2]
        middle_ok = matched[1].get_transpose_perm() == [0, 3, 2, 1]
        return farthest_ok and middle_ok

    def get_start_node_preds(self):
        """
        Return the non-constant predecessors of a start node.

        More than one predecessor is accepted only for an add op with exactly two inputs or
        for a concat op.

        Raises:
            CantFindGraphStartError: for any other op with more than one predecessor.
        """
        preds = [vertex for vertex in self._graph.predecessors(self) if vertex.op not in CONST_OPS]
        count = len(preds)

        if count <= 1:
            return preds
        if self.op in ADD_OPS and count == 2:
            return preds
        if self.op in CONCAT_OPS:
            # count >= 2 is already guaranteed at this point
            return preds

        msg = (
            f"Start node {self.name} of type {self.op} has illegal number of input nodes({len(preds)}), which is "
            "not supported."
        )
        raise CantFindGraphStartError(msg)

    def is_null_not_keepdims(self):
        """Tell whether a directly following Unsqueeze uses the same axes as this vertex."""
        following_unsqueeze = look_for_node(self._graph, self, [FwdChainNode(op="Unsqueeze")])
        if following_unsqueeze is None:
            return False
        return following_unsqueeze.get_axes_information() == self.get_axes_information()

    def is_global_pool(self):
        """
        Tell whether this pooling vertex covers the whole spatial extent of its input.

        ``GlobalAveragePool``, ``Mean`` and ``ReduceMean`` always qualify. ``AveragePool``
        and ``MaxPool`` qualify when the kernel height and width equal dims 1 and 2 of the
        first input shape. Every other op yields ``False``.
        """
        # NOTE(review): "GlobalMaxPool" is in POOL_OPS but is not listed below, so it reports
        # False - confirm this is intended. "Reshape" can never get here because it is not in
        # POOL_OPS.
        if self.op not in POOL_OPS:
            return False

        if self.op in ["GlobalAveragePool", "Mean", "ReduceMean", "Reshape"]:
            return True

        if self.op not in ["AveragePool", "MaxPool"]:
            return False

        dims = self.get_kernel_shape()
        # a 1-D kernel is treated as height 1 x width dims[0]
        if len(dims) == 1:
            kernel_h, kernel_w = 1, dims[0]
        else:
            kernel_h, kernel_w = dims[0], dims[1]

        input_shapes = self.get_input_shapes()
        if not input_shapes:
            return False

        first_input = input_shapes[0]
        return bool(kernel_h == first_input[1] and kernel_w == first_input[2])

    def get_attribute_by_name(self, name):
        return [attr for attr in self._info.attribute if name == attr.name]

    def get_net_input_shape(self):
        return [x.dim_value for x in self._info.type.tensor_type.shape.dim]

    def parse_raw_data(self, cast_to_int=False):
        if self.op == "Identity":
            parsed_data = self._graph.values_by_vertex_name[self.name][self._info.input[0]]
        else:
            parsed_data = numpy_helper.to_array(self._info.attribute[0].t)

        if cast_to_int:
            if not np.all(np.mod(parsed_data, 1) == 0):
                raise UnsupportedModelError(
                    f"Loss of precision when parsing raw data from constant vertex {self.name}. "
                    f"Expected data was probably floating point.",
                )
            parsed_data = parsed_data.astype(int)

        return parsed_data

    def _convert_shape(self, shape, output_format=None, convert_to_nhwc=True):
        if convert_to_nhwc:
            return self.convert_nchw_to_nhwc(shape, output_format)

        return self.keep_shape(shape)

    def get_output_shapes_by_info(self, convert_to_nhwc=True, validate_zero_dims=False):
        output_shapes = []

        if not hasattr(self._info, "output"):
            return []

        for output in self._info.output:
            if output in self._graph.output_shapes:
                nchw_shapes = self._graph.output_shapes[output]
                if len(nchw_shapes) == 1:
                    if len(list(self._graph.successors(self))) > 0:
                        num_of_outputs = len(
                            [succ for succ in self._graph.successors(self) if output in succ._info.input],
                        )
                    else:
                        num_of_outputs = 1
                    nchw_shapes = nchw_shapes * num_of_outputs
                    output_shapes.extend(
                        [
                            self._convert_shape(nchw_shape, self.output_format, convert_to_nhwc)
                            for nchw_shape in nchw_shapes
                        ],
                    )

        if validate_zero_dims and any(0 in shape for shape in output_shapes):
            msg = f"{self.op} layer {self.name} has unsupported zero-dimensioned output shapes."
            raise UnsupportedZeroDimensionShapeError(msg)

        return output_shapes

    def get_output_shapes(self, convert_to_nhwc=True, validate_zero_dims=False, **kwrags):
        """Return the output shapes of this vertex.

        Shapes registered in the graph under the vertex name are converted through
        ``_convert_shape`` and repeated once per successor. When nothing is found that
        way, the lookup falls back to ``get_output_shapes_by_info``.

        Args:
            convert_to_nhwc: forwarded to ``_convert_shape`` / the fallback lookup.
            validate_zero_dims: when truthy, raise if any resulting shape contains a 0.
            **kwrags: accepted and ignored.

        Raises:
            UnsupportedZeroDimensionShapeError: ``validate_zero_dims`` is set and a shape contains 0.
        """
        graph_shapes = self._graph.output_shapes
        shapes = []

        if self.name in graph_shapes:
            for raw_shape in graph_shapes[self.name]:
                shapes.append(self._convert_shape(raw_shape, self.output_format, convert_to_nhwc))
            successor_count = len(list(self._graph.successors(self)))
            if successor_count > 0:
                shapes = shapes * successor_count

        if not shapes:
            shapes = self.get_output_shapes_by_info(convert_to_nhwc, validate_zero_dims)

        has_zero_dim = validate_zero_dims and any(0 in shape for shape in shapes)
        if has_zero_dim:
            msg = f"{self.op} layer {self.name} has unsupported zero-dimensioned output shapes."
            raise UnsupportedZeroDimensionShapeError(msg)

        return shapes

    def convert_nchw_to_nhwc(self, nchw_shape, output_format):
        """Reorder a channels-first shape into channels-last ordering.

        When ``output_format`` is given and has one entry per dimension, sizes are looked
        up by their ``Dims`` label (missing labels count as 1) and the channels, groups
        and stack sizes are multiplied into a single channels value. Otherwise rank-3
        and rank-4 shapes are permuted positionally and any other rank is returned as is.
        """
        format_matches = output_format is not None and len(nchw_shape) == len(output_format)
        if not format_matches:
            rank = len(nchw_shape)
            if rank == 3:
                batch, channels, width = nchw_shape[0], nchw_shape[1], nchw_shape[2]
                return [batch, 1, width, channels]
            if rank == 4:
                batch, channels, height, width = nchw_shape[0], nchw_shape[1], nchw_shape[2], nchw_shape[3]
                return [batch, height, width, channels]
            return nchw_shape

        size_of = dict(zip(output_format, nchw_shape))
        feature_dims = [Dims.CHANNELS, Dims.GROUPS, Dims.STACK]
        channels = int(np.prod([size_of.get(dim, 1) for dim in feature_dims]))
        batch = size_of.get(Dims.BATCH, 1)

        spatial_dims = [Dims.HEIGHT, Dims.WIDTH, Dims.GROUPS]
        if not any(dim in spatial_dims for dim in output_format):
            # no height/width/groups label in the format: collapse to [batch, channels]
            return [batch, channels]

        return [batch, size_of.get(Dims.HEIGHT, 1), size_of.get(Dims.WIDTH, 1), channels]

    def convert_batch_first(self, shape):
        """Return ``shape`` as a list with a leading dimension of size 1 prepended."""
        with_batch = [1]
        with_batch.extend(shape)
        return with_batch

    def keep_shape(self, shape):
        """Return ``shape`` unchanged (the same object, not a copy)."""
        return shape

    def get_input_shapes(self, convert_to_nhwc=True):
        """Return one shape per resolvable input of this vertex.

        An input is resolved either directly from the graph (its name appears both in
        ``output_shapes`` and in ``vertices_by_inp_key``), or through the predecessor
        vertex carrying that name, in which case the predecessor's first output shape
        is used.

        Args:
            convert_to_nhwc: forwarded to ``_convert_shape`` / ``get_output_shapes``.

        Returns:
            A list of shapes; empty when the graph holds no output shapes or the
            vertex is an input op. Inputs that cannot be resolved are skipped.
        """
        if not self._graph.output_shapes or self.op in INPUT_OPS:
            return []

        preds = list(self._graph.predecessors(self))
        pred_names = {x.name for x in preds}
        # loop-invariant: inputs whose shape and producing vertex are both known to the graph
        known_inputs = set(self._graph.output_shapes) & set(self._graph.vertices_by_inp_key)

        input_shapes = []
        for input_vertex in self._info.input:
            if input_vertex in known_inputs:
                input_shapes.append(
                    self._convert_shape(
                        self._graph.output_shapes[input_vertex][0],
                        self._graph.vertices_by_inp_key[input_vertex].output_format,
                        convert_to_nhwc,
                    ),
                )
            elif input_vertex in pred_names:
                pred = next(x for x in preds if input_vertex in x._info.output)
                pred_out_shapes = pred.get_output_shapes(
                    convert_to_nhwc=convert_to_nhwc,
                )
                if pred_out_shapes:
                    # use the predecessor's actual shape (previously a literal [0] placeholder was appended)
                    input_shapes.append(pred_out_shapes[0])

        return input_shapes

    def get_dense_reshape_vertices(self):
        """Return the Cast and/or Constant vertices found directly behind this vertex.

        The Constant is searched either as a direct backward neighbour or behind a
        Cast. Vertices that are not found are simply left out of the returned list.
        """
        cast_node = look_for_node(self._graph, self, [BwdChainNode(op="Cast")])
        const_chains = [
            [BwdChainNode(op="Constant")],
            [BwdChainNode(op="Cast"), BwdChainNode(op="Constant")],
        ]
        const_node = get_node_from_possible_chains(self._graph, self, const_chains)
        return [vertex for vertex in (cast_node, const_node) if vertex]

    def get_kernel(self, is_conv2d=True):
        """Locate the kernel (weights) of this vertex.

        The values stored for the vertex in ``values_by_vertex_name`` are used when
        present. Otherwise the predecessors are scanned: for dense ops, values stored on
        a Transpose predecessor are adopted and Reshape predecessors are consumed; a
        constant predecessor with more than one dimension supplies the kernel directly.
        As a last resort a predecessor recognised by ``is_normalized_kernel`` is used.

        Args:
            is_conv2d: when truthy (and the input format contains groups with more than
                one group) the kernel is tiled along the groups index.

        Returns:
            ``(kernel, consumed_vertices)``. ``kernel`` is ``None`` when no stored value
            has a key matching ``KERNEL_INITIALIZERS_NAMES``.
        """
        # NOTE(review): when an (empty) mapping is stored for this vertex, the updates below
        # are applied to that stored mapping itself - confirm that this is acceptable
        vertex_params = self._graph.values_by_vertex_name.get(self.name, {})

        consumed_vertices = []
        if not vertex_params:
            for pred in self._graph.predecessors(self):
                # edge case: dense layer stores weights in preceding transpose/reshape vertices
                if self.op in DENSE_OPS:
                    if pred.op == "Transpose":
                        transpose_values = self._graph.values_by_vertex_name[pred.name]
                        if transpose_values:
                            vertex_params.update(transpose_values)
                            if self.op == "MatMul":
                                # swap the two axes of every kernel-like value taken from the transpose
                                for key, val in vertex_params.items():
                                    if any(x in key for x in KERNEL_INITIALIZERS_NAMES):
                                        vertex_params[key] = val.transpose([1, 0])
                    elif pred.op == "Reshape":
                        consumed_vertices.append(pred)
                        consumed_vertices.extend(pred.get_dense_reshape_vertices())
                if pred.op in CONST_OPS:
                    # only constants with more than one dimension are treated as a kernel
                    const_shape = pred._info.attribute[0].t.dims
                    if len(const_shape) > 1:
                        consumed_vertices.append(pred)
                        vertex_params["kernel"] = pred.parse_raw_data()

        if "kernel" not in vertex_params:
            # fallback: kernel produced by a normalization sub-graph ending in a Reshape predecessor
            for pred in self._graph.predecessors(self):
                if pred.is_normalized_kernel():
                    vertex_params["kernel"] = pred.get_normalized_kernel()

        kernel = None
        # the first value whose key contains one of the known kernel names wins
        weights = [val for key, val in vertex_params.items() if any(x in key for x in KERNEL_INITIALIZERS_NAMES)]
        if weights:
            kernel = weights[0]
            if self.input_format and Dims.GROUPS in self.input_format and is_conv2d and self.get_groups() > 1:
                # incase the input format contains groups, the kernel should be tiled
                groups_idx = self.input_format.index(Dims.GROUPS)
                groups = self.get_input_shapes(convert_to_nhwc=False)[0][groups_idx]
                reps = [1 if i != groups_idx else groups for i in range(len(kernel.shape))]
                kernel = np.tile(kernel, reps=reps)

        # handle 1D layers or rank3 MatMul (encoder)
        if len(np.shape(kernel)) == 3:
            # a singleton axis is inserted at index 0 for dense ops whose output format ends with
            # [WIDTH, CHANNELS], and at index 2 otherwise
            insert_index = (
                0
                if self.op in DENSE_OPS
                and self.output_format
                and self.output_format[-2:] == [Dims.WIDTH, Dims.CHANNELS]
                else 2
            )
            new_kernel_shape = list(np.shape(kernel))
            new_kernel_shape.insert(insert_index, 1)
            kernel = kernel.reshape(new_kernel_shape)

        return kernel, consumed_vertices

    def is_normalized_kernel(self):
        """Tell whether this Reshape vertex is the tail of a kernel-normalization sub-graph.

        The pattern is a Reshape fed by a BatchNormalization with exactly three
        predecessors: a constant followed by two ReduceMean vertices, where the last
        ReduceMean is reached backwards through Mul -> Sub -> ReduceMean.
        """
        if self.op != "Reshape":
            return False

        batch_norm = look_for_node(self._graph, self, [BwdChainNode(op="BatchNormalization")])
        if not batch_norm:
            return False

        bn_preds = list(self._graph.predecessors(batch_norm))
        if len(bn_preds) != 3:
            return False

        const_pred, first_mean, second_mean = bn_preds
        if const_pred.op not in CONST_OPS:
            return False
        if not (first_mean.op == second_mean.op == "ReduceMean"):
            return False

        variance_chain = [BwdChainNode(op="Mul"), BwdChainNode(op="Sub"), BwdChainNode(op="ReduceMean")]
        return bool(look_for_node(self._graph, second_mean, variance_chain))

    def get_normalized_kernel(self):
        """Return the graph entry stored under ``<first output name>_value`` in ``output_shapes``."""
        value_key = self._info.output[0] + "_value"
        return self._graph.output_shapes[value_key]

    def get_bias(self):
        """Locate the bias values that belong to this vertex.

        Lookup order:
          1. For add ops: the single stored value of the vertex, or the raw data of its
             first constant predecessor (which is then reported as consumed).
          2. For dense ops with exactly one successor and no stored bias: the stored
             value of a successive add vertex that has a single predecessor.
          3. Otherwise: the first stored value whose key matches ``BIAS_INITIALIZERS_NAMES``.

        Returns:
            ``(bias, consumed_vertices)``; ``(None, [])`` when no bias is found.
        """
        # fall back to an empty mapping so that vertices without stored values are handled
        # like vertices without a bias instead of failing on the iterations below
        vertex_params = self._graph.values_by_vertex_name.get(self.name) or {}
        if self.op in ADD_OPS:
            const_preds = [x for x in self.graph.predecessors(self) if x.op in CONST_OPS]
            if vertex_params and len(vertex_params) == 1:
                return np.asarray(next(iter(vertex_params.values()))).flatten(), []
            if const_preds:
                return const_preds[0].parse_raw_data().flatten().tolist(), const_preds

        # edge case: dense layer stores bias in successive add vertex
        consumed_vertices = []
        if (
            self.op in DENSE_OPS
            and len(list(self._graph.successors(self))) == 1
            and not any(any(bias_name in key for bias_name in BIAS_INITIALIZERS_NAMES) for key in vertex_params)
        ):
            vertex_params = {}
            for succ in self._graph.successors(self):
                if succ.op in ADD_OPS and len(list(self._graph.predecessors(succ))) == 1:
                    add_params = self._graph.values_by_vertex_name[succ.name]
                    if add_params:
                        bias = next(iter(add_params.values()))
                        vertex_params.update({"bias": bias})
                        consumed_vertices.append(succ)

        bias = None
        bias_values = [val for key, val in vertex_params.items() if any(x in key for x in BIAS_INITIALIZERS_NAMES)]
        if bias_values:
            bias = bias_values[0]

        if bias is None:
            return None, []

        return bias, consumed_vertices

    def get_strides(self):
        """Return the strides as a 4-element list ``[1, stride_h, stride_w, 1]``.

        A single stride value is used as the width stride with a height stride of 1;
        with more values the last two are taken. Without a ``strides`` attribute all
        strides are 1.
        """
        strides_attr = self.get_attribute_by_name("strides")
        if not strides_attr:
            return [1, 1, 1, 1]

        values = strides_attr[0].ints
        if len(values) == 1:
            stride_h, stride_w = 1, values[0]
        else:
            stride_h, stride_w = values[-2], values[-1]
        return [1, stride_h, stride_w, 1]

    def get_kernel_shape(self):
        """Return the integer values of the vertex's ``kernel_shape`` attribute."""
        kernel_shape_attr = self.get_attribute_by_name("kernel_shape")
        first_match = kernel_shape_attr[0]
        return first_match.ints

    def get_leaky_alpha(self):
        """Return the ``alpha`` (negative slope) attribute of the vertex.

        Falls back to 0.01, the default ONNX defines for LeakyRelu, when the
        attribute is omitted from the node.
        """
        alpha_attr = self.get_attribute_by_name("alpha")
        # alpha is 0.01 by default in onnx LeakyRelu
        return alpha_attr[0].f if alpha_attr else 0.01

    def get_prelu_slope(self):
        """Find the slope values of a PRelu activation.

        For a ``PRelu`` vertex the slope is taken from the values stored for the vertex,
        or from the vertex feeding its slope input (raw constant data, stored values, or
        a preceding ``Constant``). For a ``Relu`` vertex the decomposed pattern
        ``Relu -> Add <- Mul <- Relu <- Neg`` is matched and the negated Mul factor is
        used as the slope.

        Returns:
            ``(slope, consumed_vertices)`` where ``slope`` is a flattened array.

        Raises:
            UnsupportedActivationLayerError: no slope values could be located.
        """
        if self.op == "PRelu":
            stored_values = self._graph.values_by_vertex_name
            # a missing entry is treated as "no stored values" instead of failing on the lookup
            var_initializer = stored_values.get(self.name) or {}
            index = PRELU_INPUT_ORDER.index("slope")
            param_input = self._info.input[index]
            if var_initializer:
                key = next(iter(var_initializer))
                return np.array(var_initializer[key], dtype=float).flatten(), []

            if param_input in self._graph.vertices_by_inp_key:
                # Try to get the values from constant inputs
                const = self._graph.vertices_by_inp_key[param_input]
                try:
                    return np.array(const.parse_raw_data(), dtype=float).flatten(), [const]
                # edge case when do_constant_folding=False, should_simplify=False
                except (TypeError, IndexError):
                    slopes = list((stored_values.get(const.name) or {}).values())
                    if len(slopes) > 0:
                        return np.array(slopes, dtype=float).flatten(), [const]
                    const_preds = list(self._graph.predecessors(const))
                    if len(const_preds) > 0 and const_preds[0].op == "Constant":
                        return np.array(const_preds[0].parse_raw_data(), dtype=float).flatten(), [const, const_preds[0]]

        elif self.op == "Relu":
            consumed_vertices = get_all_nodes_in_chain(
                self.graph,
                self,
                [FwdChainNode(op="Add"), BwdChainNode(op="Mul"), BwdChainNode(op="Relu"), BwdChainNode(op="Neg")],
            )
            # an unmatched chain falls through to the dedicated error below
            if consumed_vertices:
                mul = consumed_vertices[1]
                prelu_slope, mul_vertices, _ = mul.get_normalization_input_raw_values()
                consumed_vertices.extend(mul_vertices)
                if prelu_slope is not None:
                    # invert the consumed vertices order to have proper original names order in the created layer
                    return -prelu_slope.flatten(), consumed_vertices[::-1]

        raise UnsupportedActivationLayerError(f"Unable to find slopes for PRelu activation {self.name}.")

    def get_hardsigmoid_info(self):
        """Return ``(alpha, beta, consumed_vertices)`` describing a hard-sigmoid activation.

        A native ``HardSigmoid`` vertex supplies alpha/beta through its attributes
        (defaults 0.2 and 0.5). Any other vertex is treated as the head of a
        ``Sub -> Div -> Clip`` chain, from which ``alpha = 1 / div`` and
        ``beta = -sub / div`` are derived.
        """
        if self.op == "HardSigmoid":
            alpha_attr = self.get_attribute_by_name("alpha")
            alpha = 0.2 if not alpha_attr else alpha_attr[0].f

            # beta is 0.5 by default in onnx
            beta_attr = self.get_attribute_by_name("beta")
            beta = 0.5 if not beta_attr else beta_attr[0].f
            return alpha, beta, []

        chain_nodes = get_all_nodes_in_chain(self.graph, self, [FwdChainNode("Div"), FwdChainNode("Clip")])
        consumed = []
        consumed.extend(chain_nodes)
        div_node = chain_nodes[0]

        sub_value, sub_vertices, _ = self.get_normalization_input_raw_values()
        div_value, div_vertices, _ = div_node.get_normalization_input_raw_values()
        consumed.extend(sub_vertices)
        consumed.extend(div_vertices)

        alpha = 1 / div_value
        beta = (-1) * sub_value / div_value
        return alpha, beta, consumed

    def is_hardsigmoid(self):
        """Tell whether this vertex is a hard-sigmoid, native or decomposed.

        The decomposed form is a subtraction of a scalar (as second operand), followed
        by a division by a scalar (as second operand), followed by a clip to [0, 1].
        """

        def scalar_second_operand(value, operand_index):
            # the raw value must exist, sit at input index 1 and be zero-dimensional
            return value is not None and operand_index == 1 and value.shape == ()

        if self.op == "HardSigmoid":
            return True
        if self.op not in SUB_OPS:
            return False

        shift, _, shift_index = self.get_normalization_input_raw_values()
        if not scalar_second_operand(shift, shift_index):
            return False

        chain = get_all_nodes_in_chain(self.graph, self, [FwdChainNode("Div"), FwdChainNode("Clip")])
        if not chain:
            return False
        div_node, clip_node = chain

        scale, _, scale_index = div_node.get_normalization_input_raw_values()
        if not scalar_second_operand(scale, scale_index):
            return False

        lower, upper, _ = clip_node.get_clip_info()
        return lower == 0 and upper == 1

    def get_min_max_clip_info(self):
        """Describe this vertex as a clip operation ``[min_value, max_value]``.

        ``Min`` (optionally followed by ``Max``) and ``Max`` (optionally followed by
        ``Min``) are read from their initializer/constant operand; ``Clip`` delegates to
        ``get_clip_info``. A missing bound is represented by ``-np.inf`` / ``np.inf``.

        Returns:
            ``(min_value, max_value, consumed_vertices)``, or ``(None, None, [])`` when
            a required bound value cannot be found.
        """

        def first_value(vertex, order):
            # the lookup returns None when neither an initializer nor a constant exists
            values = vertex.get_initializer_or_constant_value(order)
            return values[0] if values else None

        consumed_vertex = []
        min_value = None
        max_value = None

        # NOTE(review): the truthiness checks below also treat a bound of exactly 0 as
        # missing; kept as is - confirm whether zero-valued bounds should be accepted
        if self.op == "Min":
            max_value = first_value(self, MIN_INPUT_ORDER)
            if not max_value:
                return None, None, []
            max_node = look_for_node(self.graph, self, [FwdChainNode(op="Max")])
            if max_node:
                # structure of (min) -> (max) perform by clipping of [min_value, max_value]
                min_value = first_value(max_node, MAX_INPUT_ORDER)
                if not min_value:
                    return None, None, []
                consumed_vertex.append(max_node)
        elif self.op == "Max":
            min_value = first_value(self, MAX_INPUT_ORDER)
            if not min_value:
                return None, None, []
            min_node = look_for_node(self.graph, self, [FwdChainNode(op="Min")])
            if min_node:
                # structure of (max) -> (min) perform by clipping of [min_value, max_value]
                max_value = first_value(min_node, MIN_INPUT_ORDER)
                if not max_value:
                    return None, None, []
                consumed_vertex.append(min_node)
        elif self.op == "Clip":
            min_value, max_value, consumed_vertex = self.get_clip_info()

        # Representing clip with only max attr or standalone min as clip to [-np.inf, max_value]
        min_value = min_value if min_value is not None else -np.inf

        # Representing clip with only min attr or standalone max as clip to [min_value, np.inf]
        max_value = max_value if max_value is not None else np.inf

        return min_value, max_value, consumed_vertex

    def get_initializer_or_constant_value(self, order):
        """Return the vertex's initializer (or preceding constant) as a flat Python list.

        Args:
            order: input-order list; the position of ``"initializer"`` in it selects
                which input name is looked up in the stored values.

        Returns:
            The flattened values, or ``None`` when the vertex has no stored values and
            no ``Constant`` vertex precedes it.
        """
        stored = self._graph.values_by_vertex_name.get(self.name, None)
        if stored:
            initializer_key = self._info.input[order.index("initializer")]
            return stored[initializer_key].flatten().tolist()

        const_node = look_for_node(self._graph, self, [BwdChainNode(op="Constant")])
        if not const_node:
            return None
        return const_node.parse_raw_data().flatten().tolist()

    def get_activation_less_or_greater_values(self):
        """Translate a Less/Greater comparison against a constant into an activation.

        Returns:
            ``(activation_type, value, consumed_vertices)``. The activation type depends
            on the op and on the index returned with the raw values: for index 0 the
            comparison direction is flipped, for index 1 it is kept.

        Raises:
            UnsupportedActivationLayerError: no raw values were found, or the values
                are not all identical.
        """
        cast_node = look_for_node(self.graph, self, [FwdChainNode(op="Cast")])
        consumed = [cast_node] if cast_node else []

        raw_value, value_vertices, value_index = self.get_normalization_input_raw_values()
        if raw_value is None:
            raise UnsupportedActivationLayerError(
                f"Unable to find values to compare for {self.op} activation {self.name}.",
            )

        flat_value = raw_value.flatten()
        if np.any(flat_value != flat_value[0]):
            raise UnsupportedActivationLayerError(
                f"{self.op} activation {self.name} with vector values is not supported.",
            )

        activation_by_index_and_op = {
            (1, "Less"): ActivationType.less,
            (1, "Greater"): ActivationType.greater,
            (0, "Less"): ActivationType.greater,
            (0, "Greater"): ActivationType.less,
        }
        activation = activation_by_index_and_op[(value_index, self.op)]

        return activation, flat_value, consumed + value_vertices

    def get_groups(self):
        """Return the value of the ``group`` attribute, or 1 when the attribute is absent."""
        group_attr = self.get_attribute_by_name("group")
        if not group_attr:
            return 1
        return group_attr[0].i

    def get_dilations(self):
        """Return the dilations as a 4-element list ``[1, dilation_h, dilation_w, 1]``.

        A single dilation value is used as the width dilation with a height dilation
        of 1; otherwise the first two values are taken. Without a ``dilations``
        attribute all dilations are 1.
        """
        dilations_attr = self.get_attribute_by_name("dilations")
        if not dilations_attr:
            return [1, 1, 1, 1]

        values = dilations_attr[0].ints
        if len(values) == 1:
            return [1, 1, values[0], 1]
        return [1, values[0], values[1], 1]

    def get_vertex_padding(self):
        """Extract the padding scheme, padding values and padding constant of this vertex.

        Padding values come from the ``pads`` attribute when present. For pad ops
        without that attribute they are taken, in order of preference, from a constant
        vertex feeding the pads input, from the values stored for the vertex, or from
        the graph's ``output_shapes`` entry of the pads input.

        Returns:
            ``(padding, pads, padding_const_value, consumed_vertices)`` where ``pads``
            has been reordered by ``_get_onnx_to_hn_pads``.

        Raises:
            UnsupportedPaddingError: ``auto_pad`` (other than NOTSET) is combined with
                explicit pads, or no padding values can be extracted for a pad op.
        """
        consumed_vertices = []
        padding_const_value = 0
        pads_attr = self.get_attribute_by_name("pads")
        auto_pad_attr = self.get_attribute_by_name("auto_pad")
        auto_pad_attr_value = auto_pad_attr[0].s.decode() if auto_pad_attr else ""

        if auto_pad_attr and auto_pad_attr_value != "NOTSET" and pads_attr:
            raise UnsupportedPaddingError(
                f"Vertex {self.name} has auto_pad set in conjunction with pads values, which is not supported.",
            )

        pads = [0, 0, 0, 0, 0, 0] if not pads_attr else pads_attr[0].ints
        if self.op in PAD_OPS:
            if not pads_attr:
                # attempt to get const value from external padding vertex
                var_initializers = self._graph.values_by_vertex_name.get(self.name, {})
                pad_mode = self.get_attribute_by_name("mode")

                constant_value_index = PAD_INPUT_ORDER.index("constant_value")
                if pad_mode and pad_mode[0].s.decode() == "constant" and len(self._info.input) > constant_value_index:
                    padding_const_value_key = self._info.input[constant_value_index]
                    # NOTE(review): this truthiness test presumably expects a scalar / single-element
                    # value; a stored value of 0 is skipped, which leaves the default of 0 in place
                    if var_initializers.get(padding_const_value_key):
                        padding_const_value = var_initializers[padding_const_value_key].tolist()
                    # a vertex feeding the constant-value input takes precedence over the stored value
                    padding_const_value_vertex = self._graph.vertices_by_inp_key.get(padding_const_value_key)
                    if padding_const_value_vertex is not None:
                        consumed_vertices.append(padding_const_value_vertex)
                        padding_const_value = padding_const_value_vertex.parse_raw_data()

                    if isinstance(padding_const_value, list):
                        padding_const_value = padding_const_value[0]

                pads_key = self._info.input[PAD_INPUT_ORDER.index("pads")]
                pads_const = self._graph.vertices_by_inp_key.get(pads_key)
                if pads_const is not None and pads_const.op in CONST_OPS:
                    consumed_vertices.append(pads_const)
                    pads = pads_const.parse_raw_data(cast_to_int=True).flatten().tolist()
                elif var_initializers:
                    # NOTE(review): assumes the stored values contain the pads key whenever they are non-empty
                    pads = [int(x) for x in var_initializers[pads_key].tolist()]
                elif self._graph.output_shapes:
                    # last resort: the value recorded in the graph under "<second input name>_value"
                    pads = [int(x) for x in self._graph.output_shapes[self._info.input[1] + "_value"]]
                    consumed_vertices.extend(self.get_all_shape_nodes())
                else:
                    raise UnsupportedPaddingError(f"Could not extract padding values for vertex {self.name}.")

            if len(pads) >= 8:
                # known issue following deprecation in onnx proto (https://github.com/NVIDIA/TensorRT/issues/195)
                pads = [pads[1], pads[2], pads[3], pads[5], pads[6], pads[7]]
            elif len(pads) == 4 and self.output_format == [Dims.BATCH, Dims.CHANNELS]:
                # onnx pads are [beg_b, beg_c, end_b, end_c], ignoring batch pads and take only channels pads
                pads = [pads[1], pads[3]]

        # get padding type from auto_pad attr, or infer from paddings vals
        if auto_pad_attr:
            padding = AUTO_PAD_TO_PADDING_TYPE[auto_pad_attr_value]
        else:
            padding = TemporaryPaddingType.external_undecided

        # fallback for supported external (undecided) zero paddings that can be used as valid
        if padding == TemporaryPaddingType.external_undecided and all(pad == 0 for pad in pads):
            padding = PaddingType.valid

        pads = self._get_onnx_to_hn_pads(pads)
        return padding, pads, padding_const_value, consumed_vertices

    def _get_onnx_to_hn_pads(self, pads):
        """Reorder ONNX padding values into the HN layout.

        Args:
            pads: 6, 4 or 2 padding values in ONNX order (all begin values followed by
                all end values).

        Returns:
            ``[beg_h, end_h, beg_w, end_w, beg_f, end_f]``.

        Raises:
            UnsupportedPaddingError: ``pads`` has a length other than 6, 4 or 2.
        """
        # pads order is [beg_f, beg_h, beg_w, end_f, end_h, end_w],
        # as https://github.com/onnx/onnx/blob/master/docs/Operators.md#pad
        if len(pads) == 6:
            hn_pads = [pads[1], pads[4], pads[2], pads[5], pads[0], pads[3]]
        elif len(pads) == 4:
            hn_pads = [pads[0], pads[2], pads[1], pads[3], 0, 0]
        elif len(pads) == 2:
            hn_pads = [0, pads[0], 0, pads[1], 0, 0]
            if self.output_format == [Dims.BATCH, Dims.CHANNELS, Dims.WIDTH]:
                hn_pads = [0, 0, pads[0], pads[1], 0, 0]
            elif self.output_format == [Dims.BATCH, Dims.CHANNELS]:
                hn_pads = [0, 0, 0, 0, *pads]
        else:
            # fail with a descriptive error instead of an unbound local variable
            raise UnsupportedPaddingError(
                f"Vertex {self.name} has an unsupported number of padding values ({len(pads)}).",
            )
        # hn_pads is [beg_h, end_h, beg_w, end_w, beg_f, end_f]
        return hn_pads

    def is_null_padding(self):
        """Tell whether this vertex is a pad op whose padding values are all zero."""
        if self.op in PAD_OPS:
            _, hn_pads, _, _ = self.get_vertex_padding()
            return hn_pads == [0] * 6
        return False

    def is_non_const_padding(self):
        """Tell whether the ``mode`` attribute is present and does not contain ``"constant"``."""
        mode_attr = self.get_attribute_by_name("mode")
        if not mode_attr:
            return False
        decoded_mode = str(mode_attr[0].s.decode())
        return "constant" not in decoded_mode

    def get_bn_info(self):
        """Build the batch-normalization values of this vertex.

        Gamma and beta are overridden by the external weight/bias returned from
        ``consume_dummy_bn_ops`` when both of them are available.

        Returns:
            ``(BatchNormValues, consumed_vertices)``.
        """
        bn_info, consumed_vertices = self.get_batch_or_instance_normalization_info(input_order=BN_INPUT_ORDER)

        dummy_vertices, ext_weight, ext_bias = self.consume_dummy_bn_ops()
        consumed_vertices.extend(dummy_vertices)
        has_external_affine = ext_weight is not None and ext_bias is not None
        if has_external_affine:
            bn_info["gamma"], bn_info["beta"] = ext_weight, ext_bias

        return (
            BatchNormValues(
                moving_mean=bn_info["mean"],
                moving_variance=bn_info["var"],
                gamma=bn_info["gamma"],
                beta=bn_info["beta"],
                epsilon=bn_info["epsilon"],
            ),
            consumed_vertices,
        )

    def get_batch_or_instance_normalization_info(self, input_order):
        """Gather the parameters of a batch/instance normalization vertex.

        Args:
            input_order: parameter names in the order of the node's inputs; the first
                entry (the data input) is skipped.

        Returns:
            ``(info_dict, consumed_vertices)``. ``info_dict`` maps every parameter that
            was found to its values, plus ``"epsilon"``. Constant vertices that supplied
            values are reported as consumed.
        """
        info_dict = {}
        consumed_vertices = []
        values = self._graph.values_by_vertex_name[self.name]
        for index, param in enumerate(input_order[1:], start=1):
            param_input = self._info.input[index]
            if param_input in values:
                # Try to get the values from variable initializer inputs
                info_dict[param] = values[param_input]
            elif param_input in self._graph.vertices_by_inp_key:
                # Try to get the values from constant inputs
                const = self._graph.vertices_by_inp_key[param_input]
                info_dict[param] = const.parse_raw_data()
                consumed_vertices.append(const)

        # tiles the info over channels when the layer has groups
        if info_dict and self.output_format and Dims.GROUPS in self.output_format:
            groups_axis = self.output_format.index(Dims.GROUPS)
            # the repetition count is the same for every parameter, so it is looked up once
            groups = self.get_input_shapes(convert_to_nhwc=False)[0][groups_axis]
            for key in info_dict:
                info_dict[key] = np.tile(info_dict[key], groups)

        # epsilon is 1e-5 by default in onnx (BatchNormalization and InstanceNormalization)
        epsilon_attr = self.get_attribute_by_name("epsilon")
        info_dict["epsilon"] = epsilon_attr[0].f if epsilon_attr else 1e-5
        return info_dict, consumed_vertices

    def get_tile_repeats(self):
        """Return ``(repeats, consumed_vertices)`` for the second input of this vertex.

        The repeats come from the values stored for the vertex, or else from the raw
        data of the vertex feeding that input (which is then reported as consumed).
        ``repeats`` is ``None`` when neither source is available.
        """
        repeats_input = self._info.input[1]
        stored = self._graph.values_by_vertex_name.get(self.name, None)

        # Try to get the values from variable initializer inputs
        if stored is not None and repeats_input in stored:
            return stored[repeats_input], []

        # Try to get the values from constant inputs
        if repeats_input in self._graph.vertices_by_inp_key:
            const = self._graph.vertices_by_inp_key[repeats_input]
            return const.parse_raw_data(), [const]

        return None, []

    def get_concat_info(self, const_layers=None):
        """Return ``(axis, group_sizes, output_shapes)`` for a concat vertex.

        The axis starts at ``DEFAULT_CONCAT_AXIS`` and, when no constant layers are
        involved, may be switched to a spatial axis by the structural checks. When the
        output format covers the ONNX axis, the axis and group sizes derived from that
        format take precedence.

        Raises:
            UnsupportedZeroDimensionShapeError: an output shape contains a zero dimension.
        """
        axis = DEFAULT_CONCAT_AXIS
        group_sizes = None

        if not const_layers:
            concat_over_width = (
                self.is_spatial_w_concat()
                or self.is_concat_of_pos_embeds()
                or self.is_concat_of_transpose_width_features()
            )
            if concat_over_width:
                axis = ConcatAxis.spatial_w
            elif self.is_spatial_h_concat():
                axis = ConcatAxis.spatial_h

        onnx_axis = self.get_axis()
        output_format = self.output_format
        if output_format and onnx_axis < len(output_format):
            axis, group_sizes = self.get_concat_info_from_output_format(
                output_format[onnx_axis],
                self.output_format,
                onnx_axis,
            )

        return axis, group_sizes, self.get_output_shapes(validate_zero_dims=True)

    def get_concat_info_from_output_format(self, concat_dim, output_format, onnx_axis=None):
        """Map the concatenated dimension to a concat axis and optional group sizes.

        Args:
            concat_dim: the ``Dims`` label of the concatenated dimension.
            output_format: the dims format of the vertex output.
            onnx_axis: the ONNX concat axis, used for the stack special case.

        Returns:
            ``(axis, group_sizes)``; ``group_sizes`` is ``None`` unless the format has
            groups and the concat runs over channels or stack.

        Raises:
            UnsupportedConcatLayerError: ``concat_dim`` has no matching concat axis.
        """
        axis_by_dim = {
            Dims.HEIGHT: ConcatAxis.spatial_h,
            Dims.WIDTH: ConcatAxis.spatial_w,
        }
        for features_dim in (Dims.CHANNELS, Dims.GROUPS, Dims.STACK):
            axis_by_dim[features_dim] = ConcatAxis.features

        axis = axis_by_dim.get(concat_dim)
        if axis is None:
            raise UnsupportedConcatLayerError(f"Unsupported concat over axis {concat_dim}")

        first_input_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        grouped_features_concat = Dims.GROUPS in output_format and concat_dim in [Dims.CHANNELS, Dims.STACK]
        if not grouped_features_concat:
            return axis, None

        group_sizes = [1] * first_input_shape[output_format.index(Dims.GROUPS)]
        if concat_dim == Dims.STACK and onnx_axis == -1:
            group_sizes = group_sizes * first_input_shape[output_format.index(Dims.CHANNELS)]

        return axis, group_sizes

    def get_rnn_info(self):
        """Extract the weights of an RNN vertex.

        Only a forward RNN with Tanh activation and without clipping is accepted.
        Parameters are read from the values stored for the vertex or from constant
        vertices feeding it; a Squeeze directly after the vertex is consumed.

        Returns:
            ``(kernel, recurrent_kernel, bias, recurrent_bias, initial_h, consumed_vertices)``;
            the bias entries and ``initial_h`` are ``None`` when not provided.

        Raises:
            UnsupportedRNNLayerError: unsupported activation, clip or direction.
        """
        info_dict = {}
        consumed_vertices = []

        activations_attr = self.get_attribute_by_name("activations")
        if activations_attr:
            activation = activations_attr[0].strings[0].decode()
            if activation != "Tanh":
                raise UnsupportedRNNLayerError(f"activation type {activation} is not supported")

        clip_attr = self.get_attribute_by_name("clip")
        if clip_attr:
            raise UnsupportedRNNLayerError("clip is not supported")

        direction_attr = self.get_attribute_by_name("direction")
        if direction_attr:
            direction = direction_attr[0].s.decode()
            if direction != "forward":
                raise UnsupportedRNNLayerError(
                    f"direction {direction} is not supported, only forward is currently supported",
                )

        values = self._graph.values_by_vertex_name[self.name]
        num_inputs = len(self._info.input)
        for index, param in enumerate(RNN_INPUT_ORDER[1:], start=1):
            if index >= num_inputs:
                # trailing optional inputs may be omitted from the node altogether
                break
            param_input = self._info.input[index]
            if param_input in values:
                # Try to get the values from variable initializer inputs
                info_dict[param] = values[param_input]
            elif param_input in self._graph.vertices_by_inp_key:
                # Try to get the values from constant inputs
                const = self._graph.vertices_by_inp_key[param_input]
                info_dict[param] = const.parse_raw_data()
                consumed_vertices.append(const)

        squeeze = look_for_node(self._graph, self, [FwdChainNode(op="Squeeze")])
        if squeeze is not None:
            consumed_vertices.append(squeeze)

        # a leading axis is added and the last two axes are swapped
        kernel = np.transpose(np.expand_dims(info_dict["W"], axis=0), [0, 1, 3, 2])
        recurrent_kernel = np.transpose(np.expand_dims(info_dict["R"], axis=0), [0, 1, 3, 2])
        if info_dict.get("B") is not None:
            # B is split into two equal halves: input bias, then recurrent bias
            bias, recurrent_bias = np.split(np.squeeze(info_dict["B"]), 2)
        else:
            bias, recurrent_bias = None, None

        initial_h = np.expand_dims(info_dict["initial_h"], axis=0) if info_dict.get("initial_h") is not None else None

        return kernel, recurrent_kernel, bias, recurrent_bias, initial_h, consumed_vertices

    def get_lstm_info(self):
        """Extract the weights of an LSTM vertex.

        Only forward and bidirectional LSTMs with the Sigmoid/Tanh/Tanh activations,
        without clipping and without ``input_forget`` are accepted. Parameters are read
        from the values stored for the vertex or from constant vertices feeding it; a
        Squeeze directly after the vertex is consumed, and for a bidirectional LSTM the
        Transpose and Reshape that follow it are consumed as well.

        Returns:
            ``(forward_params, backward_params, direction, consumed_vertices)`` where
            each params entry is the 6-tuple built by ``_get_lstm_params``
            (all ``None`` for the backward direction of a forward-only LSTM).

        Raises:
            UnsupportedLSTMLayerError: unsupported attributes, or a bidirectional LSTM
                that is not followed by Transpose and Reshape.
        """
        info_dict = {}
        consumed_vertices = []

        activations_attr = self.get_attribute_by_name("activations")
        if activations_attr:
            activations = [string.decode() for string in activations_attr[0].strings]
            if activations != ["Sigmoid", "Tanh", "Tanh"]:
                raise UnsupportedLSTMLayerError(f"activation types {activations} are not supported")

        clip_attr = self.get_attribute_by_name("clip")
        if clip_attr:
            raise UnsupportedLSTMLayerError("clip is not supported")

        direction_attr = self.get_attribute_by_name("direction")
        if direction_attr:
            direction = direction_attr[0].s.decode()
            supported_directions = ["forward", "bidirectional"]
            if direction not in supported_directions:
                raise UnsupportedLSTMLayerError(
                    f"direction {direction} is not supported, only {', '.join(supported_directions)} are currently "
                    f"supported",
                )
            info_dict["direction"] = direction
        else:
            info_dict["direction"] = "forward"

        input_forget_attr = self.get_attribute_by_name("input_forget")
        if input_forget_attr:
            raise UnsupportedLSTMLayerError("input_forget is not supported")

        values = self._graph.values_by_vertex_name[self.name]
        num_inputs = len(self._info.input)
        for index, param in enumerate(LSTM_INPUT_ORDER[1:], start=1):
            if index >= num_inputs:
                # trailing optional inputs may be omitted from the node altogether
                break
            param_input = self._info.input[index]
            if param_input in values:
                # Try to get the values from variable initializer inputs
                info_dict[param] = values[param_input]
            elif param_input in self._graph.vertices_by_inp_key:
                # Try to get the values from constant inputs
                const = self._graph.vertices_by_inp_key[param_input]
                info_dict[param] = const.parse_raw_data()
                consumed_vertices.append(const)

        squeeze = look_for_node(self._graph, self, [FwdChainNode(op="Squeeze")])
        if squeeze is not None:
            consumed_vertices.append(squeeze)

        forward_params = self._get_lstm_params(0, info_dict)
        if info_dict["direction"] == "bidirectional":
            backward_params = self._get_lstm_params(1, info_dict)
            nodes = get_all_nodes_in_chain(
                self._graph,
                self,
                [FwdChainNode(op="Transpose"), FwdChainNode(op="Reshape")],
            )
            if nodes is None:
                msg = "Expected to find Transpose and Reshape nodes after bidirectional LSTM node"
                raise UnsupportedLSTMLayerError(msg)
            consumed_vertices.extend(nodes)
        else:
            # same arity as the tuple returned by _get_lstm_params
            backward_params = (None,) * 6

        return forward_params, backward_params, info_dict["direction"], consumed_vertices

    @staticmethod
    def _get_lstm_params(index, info_dict):
        """Slice the LSTM parameters of a single direction out of ``info_dict``.

        Args:
            index: position of the direction along the first axis of each parameter.
            info_dict: parameter name -> values, as gathered from the vertex inputs.

        Returns:
            ``(kernel, recurrent_kernel, bias, recurrent_bias, initial_h, initial_c)``.
            Bias entries are ``None`` without ``"B"``; an initial state is ``None`` when
            it is absent or is a single-element array holding 0.
        """

        def directional_kernel(name):
            # two leading axes are added and the last two axes are swapped
            return np.transpose(np.expand_dims(info_dict[name][index], axis=(0, 1)), [0, 1, 3, 2])

        def initial_state(name):
            state = info_dict.get(name)
            if state is None:
                return None
            if state.shape == (1,) and state[0] == 0:
                return None
            return np.expand_dims(state[index], axis=(0, 1))

        kernel = directional_kernel("W")
        recurrent_kernel = directional_kernel("R")

        bias, recurrent_bias = None, None
        if info_dict.get("B") is not None:
            bias, recurrent_bias = np.split(info_dict["B"][index], 2)

        initial_h = initial_state("initial_h")
        initial_c = initial_state("initial_c")

        return kernel, recurrent_kernel, bias, recurrent_bias, initial_h, initial_c

    def get_gru_info(self):
        """Collect the attributes and constant inputs needed to build a GRU layer.

        Attributes listed as unsupported may only be present on the node when
        they carry their default value; any other value rejects the node.

        Returns:
            Tuple ``(hidden_size, linear_before_reset, W, R, B, sequence_lens,
            initial_h, consumed_vertices)``. Inputs that could not be resolved to
            constant values are returned as None.

        Raises:
            UnsupportedGRULayerError: If an unsupported attribute is set to a
                non-default value.
        """
        # Function-scope import: the visible file header only imports ``numpy_helper`` from onnx.
        from onnx import helper

        attr_to_value = {}
        consumed_vertices = []
        unsupported_attrs_to_default_value = {
            "activation_alpha": 0.01,
            "activation_beta": None,
            "clip": None,
            "direction": "forward",
            "layout": 0,
        }

        def _has_default_value(attr_name):
            # Read the attribute through its declared type. Reading ``.i`` on a string or
            # float attribute always yields 0, which rejected nodes that merely spell out
            # a default such as direction="forward".
            default = unsupported_attrs_to_default_value[attr_name]
            value = helper.get_attribute_value(self.get_attribute_by_name(attr_name)[0])
            entries = value if isinstance(value, (list, tuple)) else [value]
            for entry in entries:
                if isinstance(entry, bytes):
                    # string attributes are stored as bytes
                    entry = entry.decode("utf-8")
                if isinstance(entry, float) and isinstance(default, (int, float)):
                    # float attributes are single precision in the ONNX proto, so compare with a tolerance
                    if not np.isclose(entry, default):
                        return False
                elif entry != default:
                    return False
            return True

        if any(
            attr.name in unsupported_attrs_to_default_value and not _has_default_value(attr.name)
            for attr in self._info.attribute
        ):
            raise UnsupportedGRULayerError(f"GRU layer {self.name} contains unsupported attributes")

        hidden_size = self.get_attribute_by_name("hidden_size")[0].i
        linear_before_reset = self.get_attribute_by_name("linear_before_reset")[0].i
        # extracts value of each attribute of the vertex
        self._extract_values_from_vertex(GRU_INPUT_ORDER, attr_to_value, consumed_vertices)
        # the gru can be followed by a squeeze node
        squeeze = look_for_node(self._graph, self, [FwdChainNode(op="Squeeze")])
        if squeeze is not None:
            consumed_vertices.append(squeeze)

        return (
            hidden_size,
            linear_before_reset,
            attr_to_value.get("W"),
            attr_to_value.get("R"),
            attr_to_value.get("B"),
            attr_to_value.get("sequence_lens"),
            attr_to_value.get("initial_h"),
            consumed_vertices,
        )

    def _extract_values_from_vertex(self, input_order, attr_to_value, consumed_vertices):
        """Resolve this node's inputs to constant values, keyed by their role name.

        Args:
            input_order: Role names, in the positional order of the node's inputs.
            attr_to_value: Output mapping, updated in place with ``role -> value``
                for every input whose value could be resolved.
            consumed_vertices: Output list, extended in place with the constant
                vertices whose data was read.
        """
        input_nodes = {vertex.name for vertex in self._graph if vertex.op in INPUT_OPS}
        values = self._graph.values_by_vertex_name[self.name]
        # zip() stops at the shorter sequence, so roles whose trailing optional input was
        # omitted from the node are skipped instead of raising IndexError.
        for param, param_input in zip(input_order, self._info.input):
            if param_input in values:
                # Try to get the values from variable initializer inputs
                attr_to_value[param] = values[param_input]
            elif param_input in self._graph.vertices_by_inp_key and param_input not in input_nodes:
                # Try to get the values from constant inputs
                vertex = self._graph.vertices_by_inp_key[param_input]
                if vertex.op in CONST_OPS:
                    attr_to_value[param] = vertex.parse_raw_data()
                    consumed_vertices.append(vertex)

    def get_instance_normalization_info(self):
        """Gather the parameters of an instance normalization, optionally wrapped in reshapes.

        When this vertex is a Reshape followed by InstanceNormalization and
        another Reshape, the parameters are read from the InstanceNormalization
        vertex and both chain vertices are reported as consumed.

        Returns:
            Tuple ``(info, consumed_vertices)``. ``info`` additionally carries
            "groups" (dimension 1 of the normalization vertex's first output
            shape, not converted to NHWC) and "axes" (``[1, 2, 3]``).
        """
        wrapped_chain = None
        if self.op == "Reshape":
            pattern = [FwdChainNode("InstanceNormalization"), FwdChainNode("Reshape")]
            wrapped_chain = get_all_nodes_in_chain(self.graph, self, pattern)

        is_wrapped = wrapped_chain is not None
        norm_vertex = wrapped_chain[0] if is_wrapped else self
        info, consumed_vertices = norm_vertex.get_batch_or_instance_normalization_info(
            input_order=INSTANCE_NORMALIZATION_INPUT_ORDER,
        )
        if is_wrapped:
            consumed_vertices.extend(wrapped_chain)

        first_output_shape = norm_vertex.get_output_shapes(convert_to_nhwc=False)[0]
        info["groups"] = first_output_shape[1]
        info["axes"] = [1, 2, 3]
        return info, consumed_vertices

    def get_layer_normalization_info(self):
        """Collect layer-normalization parameters from a native op or a decomposed pattern.

        For an op in LAYER_NORMALIZATION_OPS the parameters come from the node
        itself and the normalized axes are derived from the start axis, the
        output format and the input rank. Otherwise the vertex is expected to
        start a decomposed RMSNorm or LayerNorm chain, whose vertices are
        consumed and whose epsilon is read from the chain's Add vertex.

        Returns:
            Tuple ``(info_dict, rms_norm, consumed_vertices)`` where ``rms_norm``
            is True only when one of the decomposed RMSNorm chains matched.

        Raises:
            UnsupportedLayerNormLayerError: If no axes could be derived for a
                native layer-normalization op.
        """
        rms_norm = False
        if self.op in LAYER_NORMALIZATION_OPS:
            info_dict, consumed_vertices = self.get_batch_or_instance_normalization_info(
                input_order=LAYER_NORMALIZATION_INPUT_ORDER,
            )
            start_axis = self.get_axis()
            input_shapes = self.get_input_shapes(convert_to_nhwc=False)
            axes = None
            if self.output_format and self.output_format != [Dims.BATCH, Dims.CHANNELS]:
                # a known output format: normalize negative start axis, then map the axes through the format
                start_axis = start_axis + len(input_shapes[0]) if start_axis < 0 else start_axis
                axes = list(range(start_axis, len(input_shapes[0])))  # return all axes from start_axis to |rank|
                axes = self._convert_axes_to_nhwc(axes)
            elif input_shapes:
                # no usable format: fall back to a fixed table keyed by input rank and start axis
                shape_rank = len(input_shapes[0])
                if shape_rank == 2 and start_axis in [-1, 1]:
                    axes = [3]
                if shape_rank == 3:
                    # Assuming format is NCW (rank 3 after transpose width-features)
                    if start_axis in [-1, 2]:
                        axes = [3]  # features
                    if start_axis in [-2, 1]:
                        axes = [2, 3]  # width and features
                elif shape_rank == 4:
                    if start_axis in [-1, 3]:
                        axes = [2]  # width
                    if start_axis in [-2, 2]:
                        axes = [1, 2]  # height and width
                    if start_axis in [-3, 1]:
                        axes = [1, 2, 3]  # height, width and features

            if axes is None:
                raise UnsupportedLayerNormLayerError(
                    f"Layer {self.name} of type {self.op} with input shapes "
                    f"{input_shapes} and axis {start_axis} is not supported",
                )

            info_dict["axes"] = axes
        else:
            # can be either decomposed RMSNorm or LayerNorm
            rms_chains = [
                [
                    FwdChainNode("ReduceMean"),
                    FwdChainNode("Add"),
                    FwdChainNode("Sqrt"),
                    FwdChainNode("Div"),
                    FwdChainNode("Mul"),
                ],
                [
                    FwdChainNode("ReduceMean"),
                    FwdChainNode("Add"),
                    FwdChainNode("Sqrt"),
                    FwdChainNode("Reciprocal"),
                    FwdChainNode("Mul"),
                ],
            ]
            consumed_vertices = get_all_nodes_from_possible_chains(self.graph, self, rms_chains)
            if consumed_vertices:
                rms_norm = True
            else:
                consumed_vertices = []
                # NOTE(review): `reduce_mean` stays unbound when self.op is neither in SUB_OPS nor
                # "ReduceMean"; presumably callers only reach this branch for those ops - confirm.
                if self.op in SUB_OPS:
                    reduce_mean = look_for_node(self.graph, self, [BwdChainNode(op="ReduceMean")])
                    consumed_vertices.extend([reduce_mean])
                elif self.op == "ReduceMean":
                    reduce_mean = self

                layer_norm_chains = [
                    [
                        FwdChainNode("Sub"),
                        FwdChainNode("Pow"),
                        FwdChainNode("ReduceMean"),
                        FwdChainNode("Add"),
                        FwdChainNode("Sqrt"),
                        FwdChainNode("Div"),
                    ],
                    [
                        FwdChainNode("Sub"),
                        FwdChainNode("Mul"),
                        FwdChainNode("ReduceMean"),
                        FwdChainNode("Add"),
                        FwdChainNode("Sqrt"),
                        FwdChainNode("Div"),
                    ],
                ]
                # NOTE(review): extend() assumes one of the chains matched; a None result would raise here.
                consumed_vertices.extend(get_all_nodes_from_possible_chains(self.graph, reduce_mean, layer_norm_chains))
                # if sub is self, removes it from the consumed vertices
                # NOTE(review): the slice drops index 0, which is the ReduceMean appended above rather
                # than the Sub vertex - confirm which one is intended to be excluded.
                consumed_vertices = consumed_vertices[1:] if reduce_mean != self else consumed_vertices

            # epsilon is the constant operand of the chain's Add vertex
            add = next(node for node in consumed_vertices if node.op == "Add")
            epsilon, epsilon_vertices, _ = add.get_normalization_input_raw_values()
            consumed_vertices.extend(epsilon_vertices)
            info_dict = {}
            # the decomposed pattern carries no learned parameters here: identity scale, zero bias
            info_dict["axes"] = [3]
            info_dict["B"] = np.array(0)
            info_dict["epsilon"] = epsilon
            info_dict["scale"] = np.array(1)
            # updates groups if needed
            if self.output_format and Dims.GROUPS in self.output_format:
                info_dict["groups"] = self.get_output_shapes(convert_to_nhwc=False)[0][
                    self.output_format.index(Dims.GROUPS)
                ]
        return info_dict, rms_norm, consumed_vertices

    def get_l2_normalization_info(self):
        """Collect the axis, scale and consumed vertices of an L2 normalization pattern.

        Returns:
            Tuple ``(axis, scale, consumed_vertices)`` where ``axis`` is the
            single reduction axis (as a one-element sequence) and ``scale`` is
            ``sqrt(1 / dim)`` for the input dimension at that axis.

        Raises:
            UnsupportedL2NormLayerError: If the reduction uses more than one axis.
        """
        if self.op == "Abs":
            pattern = ["Pow", "ReduceSum", "Pow", "Clip", "Expand", "Div"]
            consumed_vertices = get_all_nodes_in_chain(self.graph, self, [FwdChainNode(op) for op in pattern])
            # the second vertex of the chain (ReduceSum) holds the reduction axes
            axis = consumed_vertices[1].get_axes_information(convert_to_nhwc=True)
        else:
            axis = self.get_axes_information(convert_to_nhwc=True)
            drops_dims = self.get_attribute_by_name("keepdims")[0].i == 0
            # without keepdims the pattern restores the reduced axis through an Unsqueeze first
            pattern = ["Unsqueeze", "Div"] if drops_dims else ["Div"]
            consumed_vertices = get_all_nodes_in_chain(self.graph, self, [FwdChainNode(op) for op in pattern])

        if len(axis) != 1:
            raise UnsupportedL2NormLayerError("L2 normalization layer with multiple axes is not supported")

        reduced_dim = self.get_input_shapes()[0][axis[0]]
        return axis, np.sqrt(1 / reduced_dim), consumed_vertices

    def consume_dummy_bn_ops(self):
        """Collect a constant Mul, and an optional constant Add after it, that follow this vertex.

        Returns:
            Tuple ``(consumed_vertices, weight, bias)``. ``weight`` is read from
            the Mul's constant operand and ``bias`` from the Add's; either is
            None when the corresponding vertex or value was not found.
        """

        def _read_values(values_node, name_hints):
            # Constant vertices carry the data themselves; otherwise search the graph's
            # stored values for a key containing one of the name hints (last match wins).
            if values_node.op in CONST_OPS:
                return values_node.parse_raw_data().flatten().tolist()
            found = None
            if values_node.name in self._graph.values_by_vertex_name:
                for key, val in self._graph.values_by_vertex_name[values_node.name].items():
                    if any(hint in key for hint in name_hints):
                        found = val
            return found

        consumed_vertices = []
        weight = None
        bias = None
        possible_value_chains = [[BwdChainNode(op="Constant")], [BwdChainNode(op="Unsqueeze")]]

        dummy_mul = look_for_node(self._graph, self, [FwdChainNode(op="Mul")])
        if not dummy_mul:
            return consumed_vertices, weight, bias

        mul_values_node = get_node_from_possible_chains(self._graph, dummy_mul, possible_value_chains)
        if not mul_values_node:
            return consumed_vertices, weight, bias

        consumed_vertices += [dummy_mul, mul_values_node]
        weight = _read_values(mul_values_node, BN_GAMMA_NAMES)

        dummy_add = look_for_node(self._graph, dummy_mul, [FwdChainNode(op="Add")])
        if not dummy_add:
            return consumed_vertices, weight, bias

        consumed_vertices.append(dummy_add)
        add_values_node = get_node_from_possible_chains(self._graph, dummy_add, possible_value_chains)
        if add_values_node:
            consumed_vertices.append(add_values_node)
            bias = _read_values(add_values_node, BN_BETA_NAMES)

        return consumed_vertices, weight, bias

    def is_shape_op(self):
        """Return True when this vertex's op is ``Shape``."""
        op_type = self.op
        return op_type == "Shape"

    def is_null_clip(self):
        """Return True when this vertex is a Clip for which neither bound was resolved."""
        if self.op == "Clip":
            lower, upper, _ = self.get_clip_info()
            return upper is None and lower is None
        return False

    def is_clip_to_positive(self):
        """Check whether this Clip only enforces a tiny positive lower bound.

        Returns:
            Tuple ``(flag, consumed_vertices)``. When the Clip has a lower bound
            and no upper bound, ``flag`` is True iff ``0 < min < 1e-7`` and the
            vertices consumed while reading the bounds are returned; in every
            other case the result is ``(False, [])``.
        """
        if self.op == "Clip":
            lower, upper, consumed_vertices = self.get_clip_info()
            unbounded_above = not upper
            if unbounded_above and lower:
                return 0 < lower < 1e-7, consumed_vertices
        return False, []

    def get_clip_info(self):
        """Read the ``min`` / ``max`` bounds of a Clip vertex.

        Each bound is taken from the node attribute of the same name when
        present. Otherwise, for a three-input node, it is taken from the
        matching input: an initializer value, a Constant / Identity vertex, or
        a Constant vertex found behind another vertex (e.g. a Cast).

        Returns:
            Tuple ``(min, max, consumed_vertices)``; a bound that was not found
            is None.

        Raises:
            UnsupportedModelError: If a bound is produced by a vertex with no
                Constant behind it.
        """
        consumed_vertices = []
        clip_info = {}
        for bound in ("min", "max"):
            bound_attr = self.get_attribute_by_name(bound)
            if bound_attr:
                # Try to get the values from vertex attributes
                clip_info[bound] = bound_attr[0].f
                continue
            if len(self._info.input) != 3:
                continue

            bound_input = self._info.input[CLIP_INPUT_ORDER.index(bound)]
            possible_values = self._graph.values_by_vertex_name[self.name]
            if bound_input in possible_values:
                # Try to get the values from variable initializer inputs
                value = possible_values[bound_input]
                if isinstance(value, np.ndarray):
                    value = value.tolist()
                if isinstance(value, list) and len(value) == 1:
                    # 1-d array, extracting the value
                    value = value[0]
                clip_info[bound] = value
            elif bound_input in self._graph.vertices_by_inp_key:
                # Try to get the values from constant or cast inputs
                bound_vertex = self._graph.vertices_by_inp_key[bound_input]
                if bound_vertex.op in ["Constant", "Identity"]:
                    node_to_parse = bound_vertex
                else:
                    # The ONNX operator is named "Constant"; looking for "Const" never matched.
                    node_to_parse = look_for_node(self._graph, bound_vertex, [BwdChainNode(op="Constant")])
                if node_to_parse is None:
                    # fail with a clear message instead of an AttributeError on None
                    raise UnsupportedModelError(
                        f"Clip layer {self.name} has a non-constant {bound} input, which is not supported",
                    )
                clip_info[bound] = node_to_parse.parse_raw_data().item()
                consumed_vertices.append(bound_vertex)

        return clip_info.get("min"), clip_info.get("max"), consumed_vertices

    def is_biased_delta_activation(self):
        """Detect a Sign/Abs pair (in either order) followed by a constant Mul.

        Returns:
            Tuple ``(matched, bias_value, consumed_vertices)``; on a mismatch the
            result is ``(False, None, [])``.
        """
        # Sign must be followed by Abs and vice versa; any other op cannot start the pattern.
        partner_op = {"Sign": "Abs", "Abs": "Sign"}.get(self.op)
        if partner_op is None:
            return False, None, []

        pattern = [FwdChainNode(op=partner_op), FwdChainNode(op="Mul")]
        consumed_vertices = get_all_nodes_in_chain(self._graph, self, pattern)
        if not consumed_vertices or len(consumed_vertices) != 2:
            return False, None, []

        mul = consumed_vertices[-1]
        bias_value, extra_vertices, _ = mul.get_normalization_input_raw_values()
        consumed_vertices.extend(extra_vertices)
        return True, bias_value.tolist(), consumed_vertices

    def is_threshold_activation(self):
        """Detect a comparison against a constant followed by Cast and Mul on a shared input.

        Returns:
            Tuple ``(matched, threshold, consumed_vertices)``; on a mismatch the
            result is ``(False, None, [])``.
        """
        pattern = [FwdChainNode(op="Cast"), FwdChainNode(op="Mul")]
        cast_and_mul = get_all_nodes_in_chain(self._graph, self, pattern)
        if not cast_and_mul or len(cast_and_mul) != 2:
            return False, None, []

        thresh_value, thresh_vertices, _ = self.get_normalization_input_raw_values()
        if thresh_value is None:
            return False, None, []

        # the Mul must be fed by the same stem as this vertex
        if not self._has_common_stem(self, cast_and_mul[-1]):
            return False, None, []

        return True, thresh_value.tolist(), cast_and_mul + thresh_vertices

    def is_silu_activation(self):
        """Detect a Sigmoid whose output is multiplied element-wise with the Sigmoid's own input.

        Returns:
            Tuple ``(matched, consumed_vertices)``. Starting from the Sigmoid the
            following Mul is consumed; starting from the Mul nothing is consumed.
        """
        op_type = self.op
        if op_type == "Sigmoid":
            ew_mul = look_for_node(self._graph, self, [FwdChainNode(op="Mul")])
            if ew_mul and ew_mul.is_ew_mult() and self._has_common_stem(self, ew_mul):
                return True, [ew_mul]
            return False, []

        if op_type == "Mul":
            sigmoid = look_for_node(self._graph, self, [BwdChainNode(op="Sigmoid")])
            if sigmoid:
                shares_stem = self._has_common_stem(self, sigmoid)
                return shares_stem and self.is_ew_mult(), []

        return False, []

    def is_softsign_activation(self):
        """Detect a Softsign op or its decomposition ``x / (1 + |x|)``.

        Returns:
            Tuple ``(matched, consumed_vertices)``. A native Softsign consumes
            nothing; the decomposed form, detected from its Abs vertex, consumes
            the Add and Div that follow. A Div vertex defers to the Abs vertex
            found two steps behind it.
        """
        if self.op == "Softsign":
            return True, []

        if self.op == "Abs":
            nodes = get_all_nodes_from_possible_chains(
                self._graph,
                self,
                [[FwdChainNode(op="Add"), FwdChainNode(op="Div")]],
            )
            if nodes is not None:
                add_node = nodes[0]
                ew_div_node = nodes[1]

                const_val = add_node.get_normalization_input_raw_values()[0]
                if const_val is None:
                    return False, []

                # Compare explicitly instead of relying on array truthiness, which raises
                # ValueError for a constant with more than one element.
                const_val = np.asarray(const_val)
                add_val_cond = const_val.size == 1 and const_val.item() == 1.0
                if not add_val_cond:
                    return False, []

                if self._has_common_stem(self, ew_div_node):
                    return True, nodes

        elif self.op == "Div":
            abs_chain = look_for_node(self._graph, self, [BwdChainNode(op="Add"), BwdChainNode(op="Abs")])
            if abs_chain is not None:
                return abs_chain.is_softsign_activation()

        return False, []

    def is_swish_activation(self):
        """Detect the swish pattern ``x * sigmoid(beta * x)`` from either of its Mul vertices.

        Returns:
            Tuple ``(matched, beta, consumed_vertices)``. The Sigmoid and the
            element-wise Mul are consumed only when this vertex is the constant
            Mul; on a mismatch the result is ``(False, None, [])``.
        """
        if self.op != "Mul":
            return False, None, []

        graph = self._graph
        fan_in = len([pred for pred in graph.predecessors(self) if pred.op not in CONST_OPS])
        if fan_in == 1:
            # this vertex is the constant (beta) Mul: look forward
            mul_node = self
            sigmoid_node = look_for_node(graph, self, [FwdChainNode(op="Sigmoid")])
            ew_mul_node = look_for_node(graph, self, [FwdChainNode(op="Sigmoid"), FwdChainNode(op="Mul")])
        elif fan_in == 2:
            # this vertex is the element-wise Mul: look backward
            mul_node = look_for_node(graph, self, [BwdChainNode(op="Sigmoid"), BwdChainNode(op="Mul")])
            sigmoid_node = look_for_node(graph, self, [BwdChainNode(op="Sigmoid")])
            ew_mul_node = self
        else:
            return False, None, []

        if not (mul_node and sigmoid_node and ew_mul_node):
            return False, None, []

        ew_mul_data_inputs = [pred for pred in graph.predecessors(ew_mul_node) if pred.op not in CONST_OPS]
        mul_data_inputs = [pred for pred in graph.predecessors(mul_node) if pred.op not in CONST_OPS]
        if len(ew_mul_data_inputs) != 2 or len(mul_data_inputs) != 1:
            return False, None, []

        # the constant Mul and the Sigmoid must form an unbranched path
        is_single_path = (
            len(list(graph.successors(mul_node))) == 1
            and len(list(graph.predecessors(sigmoid_node))) == 1
            and len(list(graph.successors(sigmoid_node))) == 1
        )
        if not is_single_path:
            return False, None, []

        beta, _, _ = mul_node.get_normalization_input_raw_values()
        if beta is None or beta.shape not in [(), (1,)]:
            return False, None, []

        consumed_vertices = [sigmoid_node, ew_mul_node] if self == mul_node else []
        if beta.shape == (1,):
            beta = beta[0]
        return True, beta.tolist(), consumed_vertices

    def is_swish_activation_ew_mul(self):
        """Return whether this vertex is the two-input Mul that closes a swish pattern."""
        data_inputs = [pred for pred in self._graph.predecessors(self) if pred.op not in CONST_OPS]
        if self.op != "Mul" or len(data_inputs) != 2:
            return False
        return self.is_swish_activation()[0]

    def is_gelu_activation(self):
        """Detect the erf-based GELU decomposition from its Div or closing Mul vertex.

        From the Div vertex the pattern is Div -> Erf -> Add -> Mul -> Mul with a
        divisor close to ``sqrt(2)``.

        Returns:
            Tuple ``(matched, consumed_vertices)``. Starting from the Div the
            chain vertices (and the divisor Constant, if any) are consumed;
            starting from the Mul nothing is consumed.

        Raises:
            UnsupportedModelError: If the divisor comes from stored values and
                there is not exactly one of them.
        """
        if self.op == "Div":
            pattern = [FwdChainNode(op="Erf"), FwdChainNode(op="Add"), FwdChainNode(op="Mul"), FwdChainNode(op="Mul")]
            gelu_vertices = get_all_nodes_in_chain(self._graph, self, pattern)
            return_vertices = list(gelu_vertices) if gelu_vertices else []

            div_value = None
            div_const = look_for_node(self._graph, self, [BwdChainNode(op="Constant")])
            div_params = self._graph.values_by_vertex_name.get(self.name, None)
            if div_const:
                div_value = div_const.parse_raw_data().flatten()
                return_vertices.append(div_const)
            elif div_params:
                candidates = list(self.graph.values_by_vertex_name[self.name].values())
                if len(candidates) != 1:
                    raise UnsupportedModelError(f"Error occurred in layer {self.name} cannot infer div value")
                div_value = candidates[0].tolist()

            has_divisor = div_const or div_params
            div_cond = has_divisor and np.all(np.isclose(div_value, np.sqrt(2.0)))
            if div_cond and gelu_vertices:
                return True, return_vertices
            return False, []

        if self.op == "Mul":
            back_pattern = [BwdChainNode(op="Add"), BwdChainNode(op="Erf"), BwdChainNode(op="Div")]
            div = look_for_node(self._graph, self, back_pattern)
            if div:
                shares_stem = self._has_common_stem(self, div)
                return shares_stem and self.is_ew_mult(), []

        return False, []

    def is_mish_activation(self):
        """Detect the mish decomposition Softplus -> Tanh -> Mul from either end.

        Returns:
            Tuple ``(matched, consumed_vertices)``. Starting from the Softplus
            the Tanh and Mul are consumed; starting from the Mul nothing is.
        """
        op_type = self.op
        if op_type == "Softplus":
            tail = get_all_nodes_in_chain(self._graph, self, [FwdChainNode(op="Tanh"), FwdChainNode(op="Mul")])
            return (True, tail) if tail is not None else (False, [])

        if op_type == "Mul":
            softplus = look_for_node(self._graph, self, [BwdChainNode(op="Tanh"), BwdChainNode(op="Softplus")])
            if softplus:
                shares_stem = self._has_common_stem(self, softplus)
                return shares_stem and self.is_ew_mult(), []

        return False, []

    def is_simple_hardswish_activation(self):
        """Detect hardswish written as HardSigmoid followed by a Mul on a shared input.

        The HardSigmoid must use alpha close to 1/6 and beta equal to 0.5.

        Returns:
            Tuple ``(matched, consumed_vertices)``; on a match the HardSigmoid,
            the Mul and the vertices consumed by the HardSigmoid are returned.
        """
        if self.op == "HardSigmoid":
            hardsigmoid = self
            mul = look_for_node(self._graph, self, [FwdChainNode(op="Mul")])
        elif self.op == "Mul":
            mul = self
            hardsigmoid = look_for_node(self._graph, self, [BwdChainNode(op="HardSigmoid")])
        else:
            return False, []

        if not (hardsigmoid and mul):
            return False, []

        alpha, beta, hardsigmoid_vertices = hardsigmoid.get_hardsigmoid_info()
        shares_stem = self._has_common_stem(hardsigmoid, mul)
        if shares_stem and np.isclose(alpha, 0.16666666) and beta == 0.5:
            return True, [hardsigmoid, mul, *hardsigmoid_vertices]
        return False, []

    def is_hardswish_activation(self):
        """Detect hardswish, either native or decomposed as ``x * relu6(x + 3) / 6``.

        The decomposed form is recognized from its Add vertex (looking forward)
        or from its Mul vertex (looking backward, with the Div either before or
        after the Mul).

        Returns:
            Tuple ``(matched, consumed_vertices)``. A native HardSwish consumes
            nothing; a decomposed match returns this vertex, the chain vertices
            and the vertices consumed while reading the constants.
        """
        if self.op in ADD_OPS:
            forward_chains = [
                [FwdChainNode(op="Clip"), FwdChainNode(op="Div"), FwdChainNode(op="Mul")],
                [FwdChainNode(op="Clip"), FwdChainNode(op="Mul"), FwdChainNode(op="Div")],
            ]
            hardswish_vertices = get_all_nodes_from_possible_chains(self._graph, self, forward_chains)
            if not hardswish_vertices:
                return False, []

            add_node = self
            clip_node, second, third = hardswish_vertices
            # the Div may come before or after the Mul
            div_node, mul_node = (third, second) if third.op == "Div" else (second, third)

        elif self.op in MUL_OPS:
            backward_chains = [
                [BwdChainNode(op="Div"), BwdChainNode(op="Clip"), BwdChainNode(op="Add")],
                [BwdChainNode(op="Clip"), BwdChainNode(op="Add")],
            ]
            hardswish_vertices = get_all_nodes_from_possible_chains(self._graph, self, backward_chains)
            if not hardswish_vertices:
                return False, []

            mul_node = self
            if len(hardswish_vertices) == 3:
                div_node, clip_node, add_node = hardswish_vertices
            else:
                clip_node, add_node = hardswish_vertices
                # no Div behind the Mul: it has to follow it
                div_node = look_for_node(self._graph, self, [FwdChainNode(op="Div")])
                if not div_node:
                    return False, []
                hardswish_vertices.append(div_node)

        elif self.op == "HardSwish":
            return True, []

        else:
            return False, []

        consumed_vertices = [self, *hardswish_vertices]
        shares_stem = self._has_common_stem(add_node, mul_node)

        add_mean, add_std, add_nodes = add_node.get_normalization_info()
        consumed_vertices.extend(add_nodes)
        is_add_3 = all(mean == -3.0 for mean in add_mean) and all(std == 1.0 for std in add_std)

        div_mean, div_std, div_nodes = div_node.get_normalization_info()
        consumed_vertices.extend(div_nodes)
        is_div_6 = all(mean == 0.0 for mean in div_mean) and all(std == 6.0 for std in div_std)

        lower, upper, relu6_vertices = clip_node.get_min_max_clip_info()
        consumed_vertices.extend(relu6_vertices)
        is_relu6 = lower == 0.0 and upper == 6.0

        if shares_stem and mul_node.is_ew_mult() and is_relu6 and is_add_3 and is_div_6:
            return True, consumed_vertices
        return False, []

    def is_prelu_activation(self):
        """Detect PReLU decomposed as ``relu(x) + slope * relu(-x)`` from its Relu or Neg vertex.

        Returns:
            bool: True when both branches are found, the slope constant has one
            element or one element per entry of input dimension 1, and the Relu
            and Neg share a common stem.
        """
        if self.op == "Relu":
            relu = self
            to_neg = [FwdChainNode(op="Add"), BwdChainNode(op="Mul"), BwdChainNode(op="Relu"), BwdChainNode(op="Neg")]
            neg = look_for_node(self._graph, self, to_neg)
            counterpart = neg
        elif self.op in NEG_OPS:
            neg = self
            to_relu = [FwdChainNode(op="Relu"), FwdChainNode(op="Mul"), FwdChainNode(op="Add"), BwdChainNode(op="Relu")]
            relu = look_for_node(self._graph, self, to_relu)
            counterpart = relu
        else:
            return False

        if counterpart is None:
            return False

        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        if not input_shapes:
            return False
        features = input_shapes[0][1]

        # the slope is the constant operand of the Mul on the negative branch
        mul = look_for_node(self._graph, neg, [FwdChainNode(op="Relu"), FwdChainNode(op="Mul")])
        prelu_slope, _, _ = mul.get_normalization_input_raw_values()
        if prelu_slope is None:
            return False
        if len(prelu_slope.flatten()) not in [1, features]:
            return False

        return self._has_common_stem(relu, neg)

    def is_decimal_fraction_pow_activation(self):
        """Detect a power op whose exponent is below 1.

        Returns:
            Tuple ``(matched, power, consumed_vertices)``; on a mismatch the
            result is ``(False, None, [])``.
        """
        if self.op not in POW_OPS:
            return False, None, []

        power, consumed_vertices = self.get_power()
        # Hailo-8 support only power of decimal < 1 due to the piecewise activation implementation.
        # Hailo-15 will support powers > 1, which is preferable from current square support, btw.
        if power < 1.0:
            return True, power, consumed_vertices
        return False, None, []

    def is_null_add(self):
        """Return whether this vertex is an Add acting as an additive mask for softmax."""
        if self.op != "Add":
            return False
        return self.is_additive_mask_for_softmax()

    def is_broadcast_expand(self):
        """Return whether this Expand only broadcasts to match the other operand of an element-wise op.

        True when the Expand has a single successor that is an element-wise op
        with exactly two predecessors, and the other predecessor has an output
        shape equal to this vertex's first output shape.
        """
        if self.op != "Expand":
            return False

        successors = list(self._graph.successors(self))
        if len(successors) != 1 or not successors[0].is_ew_op():
            return False

        operands = list(self._graph.predecessors(successors[0]))
        if len(operands) != 2:
            return False

        return any(
            operand != self
            and any(
                self.get_output_shapes(convert_to_nhwc=False)[0] == shape
                for shape in operand.get_output_shapes(convert_to_nhwc=False)
            )
            for operand in operands
        )

    def is_null_squeeze(self):
        """Decide whether this Squeeze vertex can be treated as a no-op.

        The decision is based on the input shape of this vertex, the output
        shape of this vertex (or of a directly following Unsqueeze / Squeeze),
        the squeezed axes, and finally on a preceding
        Reshape -> Unsqueeze -> Transpose chain.

        Returns:
            bool: True when the squeeze is considered null.
        """
        if self.op != "Squeeze":
            return False

        chains = [[FwdChainNode(op="Unsqueeze")], [FwdChainNode(op="Squeeze")]]
        successor = get_node_from_possible_chains(self._graph, self, chains)
        output_node = successor if successor is not None else self

        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        output_shapes = output_node.get_output_shapes(convert_to_nhwc=False)
        axes = self.get_axes_information()

        if input_shapes and output_shapes:
            input_shape, output_shape = input_shapes[0], output_shapes[0]
            if input_shape == output_shape:
                return True

            if len(input_shape) == 4:
                if len(output_shape) == 2:
                    return output_shape[1] == input_shape[1] * input_shape[2] * input_shape[3]

                # Squeeze batch dim and spatial dims, allowed only for last node in graph
                if len(output_shape) == 1 and len(list(self.graph.successors(output_node))) == 0:
                    return output_shape[0] == input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]

                # axes may be None when they could not be resolved
                if len(output_shape) == 3 and axes is not None and len(axes) == 1:
                    return output_shape == [x for i, x in enumerate(input_shape) if i != axes[0]]

            elif len(input_shape) == 5 and len(output_shape) == 4 and axes:
                # Work on a copy: popping from the list returned by get_input_shapes would
                # alter the shape for any other holder of that list.
                remaining = list(input_shape)
                return remaining.pop(axes[0]) == 1 and remaining == output_shape

            elif len(output_shape) == 3:
                return np.prod(output_shape[1:]) == np.prod(input_shape[1:])

        if axes is not None:
            if axes == [0]:
                output_shapes = self.get_output_shapes(convert_to_nhwc=False)
                # guard against missing shapes before indexing into them
                if (
                    input_shapes
                    and output_shapes
                    and input_shapes[0][0:2] == [1, 1]
                    and output_shapes[0][0] == 1
                ):
                    return True

            if successor is not None:
                # might be the case of squeeze unsqueeze
                if successor.op == "Unsqueeze":
                    successor_axes = successor.get_axes_information()
                    if successor_axes is not None and successor_axes == axes:
                        # squeeze and unsqueeze on the same axis
                        return True

                if successor.op == "Squeeze":
                    # case of successive squeeze
                    successor_axes = successor.get_axes_information()
                    if successor_axes is not None:
                        # build plain lists so the concatenation works for any sequence type
                        axes = list(successor_axes) + list(axes)

            if axes in ([2, 3], [2, 2]):
                return True

        if self.get_input_shapes() == self.get_output_shapes():
            # squeeze over stack / groups
            return True

        # Squeeze part of spatial flatten + unsqueeze + null transpose + squeeze chain
        chain = [BwdChainNode(op="Transpose"), BwdChainNode(op="Unsqueeze"), BwdChainNode(op="Reshape")]
        nodes = get_all_nodes_in_chain(self._graph, self, chain)
        return bool(nodes and nodes[-1].is_spatial_flatten_reshape())

    def is_null_unsqueeze(self):
        """Decide whether this Unsqueeze vertex can be treated as a no-op.

        For a successive flat-to-frame unsqueeze, it is null only when its single
        successor is a Squeeze over the same axes. Otherwise it is null unless
        it belongs to an unsqueeze-resize-nearest or unsqueeze-tile pattern.
        """
        if self.op != "Unsqueeze":
            return False

        followers = list(self._graph.successors(self))
        if self.is_successive_unsqueeze_flat_to_frame():
            axes = self.get_axes_information()
            if len(followers) != 1:
                return False
            follower = followers[0]
            if follower.op != "Squeeze":
                return False
            return follower.get_axes_information() == axes

        return not self.is_unsqueeze_resize_nearest() and not self.is_unsqueeze_tile()

    def is_unsqueeze_tile(self):
        """Return whether this Unsqueeze belongs to an Einsum / Add / Reshape tiling pattern.

        True when the pattern's Reshape is found, this vertex's input has rank
        5, the Reshape's output has rank 6 and their last dimensions agree.
        """
        if self.op != "Unsqueeze":
            return False

        pattern = [
            BwdChainNode(op="Einsum"),
            FwdChainNode(op="Unsqueeze"),
            FwdChainNode(op="Add"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Reshape"),
        ]
        reshape = get_node_from_possible_chains(self._graph, self, [pattern])
        if not reshape:
            return False

        input_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        output_shape = reshape.get_output_shapes(convert_to_nhwc=False)[0]
        return bool(len(input_shape) == 5 and len(output_shape) == 6 and input_shape[-1] == output_shape[-1])

    def is_unsqueeze_to_stack(self):
        """Return whether this Unsqueeze follows a features-to-groups Reshape that yields a stack dimension.

        The Reshape is searched behind a Gather, optionally with a Neg between
        the Gather and this vertex.
        """
        if self.op != "Unsqueeze":
            return False

        base_chain = [BwdChainNode("Gather"), BwdChainNode("Reshape")]
        candidate_chains = [base_chain, [BwdChainNode("Neg"), *base_chain]]
        reshape = get_node_from_possible_chains(self._graph, self, candidate_chains)
        if not reshape:
            return False

        reshape_pred = next(iter(self._graph.predecessors(reshape)))
        if not reshape_pred.output_format:
            return False

        is_f_to_g, reshape_format = reshape.is_features_to_groups_reshape(reshape_pred.output_format)
        return is_f_to_g and Dims.STACK in reshape_format

    def get_squeeze_axes(self):
        """
        Return the axes to squeeze.

        Explicit axes (from get_axes_information) win. When none are given and input shapes are known,
        every axis of the first input whose size is 1 is returned instead.
        """
        explicit_axes = self.get_axes_information()
        if explicit_axes:
            return explicit_axes

        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        if not input_shapes:
            return explicit_axes

        return [idx for idx, dim in enumerate(input_shapes[0]) if dim == 1]

    def get_axes_information(self, convert_to_nhwc=False):
        """
        Return the "axes" of a Squeeze / Unsqueeze / Slice / reduce vertex.

        The axes are looked up, in order, in:
          1. the vertex's initializer values (keyed by the name of its "axes" input),
          2. the "axes" operator attribute,
          3. a Constant predecessor.

        Args:
            convert_to_nhwc: when True and axes were found, map them through _convert_axes_to_nhwc.

        Returns:
            The axes found, or the falsy result of the attribute lookup when no source provides them.

        Raises:
            KeyError: if self.op has no known input order.
        """
        possible_input_orders = {
            "Squeeze": SQUEEZE_INPUT_ORDER,
            "Unsqueeze": UNSQUEEZE_INPUT_ORDER,
            "Slice": SLICE_INPUT_ORDER,
        }
        possible_input_orders.update({reduce_op: REDUCE_INPUT_ORDER for reduce_op in REDUCE_OPS})

        # option 1: initializer variable
        values = self._graph.values_by_vertex_name.get(self.name, None)
        input_order = possible_input_orders[self.op]
        var_index = input_order.index("axes")
        axes_key = self._info.input[var_index] if var_index < len(self._info.input) else None
        # The vertex may hold initializers for other inputs only (e.g. Slice starts/ends); in that case
        # fall through to the other sources instead of failing on a missing key.
        if values and axes_key is not None and axes_key in values:
            axes = values[axes_key].tolist()
        else:
            # option 2: operator attribute
            axes = self.get_attribute_by_name("axes")
            if axes:
                axes = axes[0].ints
            else:
                # option 3: input const operator
                pred_const = look_for_node(self._graph, self, [BwdChainNode(op="Constant")])
                if pred_const is not None:
                    axes = pred_const.parse_raw_data(cast_to_int=True).flatten().tolist()
        if axes and convert_to_nhwc:
            axes = self._convert_axes_to_nhwc(axes)
        return axes

    def _convert_axes_to_nhwc(self, axes):
        """
        Map axes of a rank-4 input to their NHWC positions, using this vertex's input format.

        Axes are returned untouched when the first input is not rank 4. Only HEIGHT, WIDTH and CHANNELS
        dims have a mapping; any other dim in the input format raises KeyError.
        """
        dim_to_nhwc_index = {
            Dims.HEIGHT: 1,
            Dims.WIDTH: 2,
            Dims.CHANNELS: 3,
        }
        first_input_rank = len(self.get_input_shapes()[0])
        if first_input_rank != 4:
            return axes

        return [dim_to_nhwc_index[self.input_format[axis]] for axis in axes]

    def is_resize_slice(self):
        """Tell whether a Concat -> Resize chain follows this vertex."""
        concat_resize_chain = [FwdChainNode(op="Concat"), FwdChainNode(op="Resize")]
        resize_node = look_for_node(self._graph, self, concat_resize_chain)
        return resize_node is not None

    def is_upsample_concat(self):
        """
        Tell whether this vertex is a concat that only builds the sizes/scales of a Resize/Upsample.

        The vertex must lead to a Resize or Upsample through one of the known chains, have exactly one
        predecessor, and must not be a concat with a new input.
        """
        resize_chains = [
            [FwdChainNode(op="Cast"), FwdChainNode(op="Concat"), FwdChainNode(op="Resize")],
            [FwdChainNode(op="Resize")],
        ]
        # Upsample chains (older opset compatibility)
        upsample_chains = [
            [FwdChainNode(op="Upsample")],
            [FwdChainNode(op="Cast"), FwdChainNode(op="Div"), FwdChainNode(op="Concat"), FwdChainNode(op="Upsample")],
        ]

        upsample_node = get_node_from_possible_chains(self._graph, self, resize_chains + upsample_chains)
        pred_count = len(list(self._graph.predecessors(self)))
        if upsample_node is None or pred_count != 1:
            return False
        return not self.is_concat_with_new_input()

    def is_upsample_div(self):
        """Tell whether a Concat -> Upsample chain follows this vertex."""
        concat_upsample_chain = [FwdChainNode(op="Concat"), FwdChainNode(op="Upsample")]
        return look_for_node(self._graph, self, concat_upsample_chain) is not None

    def is_l2_norm_div(self):
        """
        Tell whether this vertex divides by an L2 norm computed over the channels dim.

        A ReduceL2 is searched backwards, directly or behind an Unsqueeze. Its first axis (attribute, or
        initializer/constant value otherwise) is looked up in the output format of the ReduceL2's first
        predecessor and must be Dims.CHANNELS.
        """
        reduce_l2_chains = [
            [BwdChainNode(op="Unsqueeze"), BwdChainNode(op="ReduceL2")],
            [BwdChainNode(op="ReduceL2")],
        ]
        reduce_l2_node = get_node_from_possible_chains(self._graph, self, reduce_l2_chains)
        if reduce_l2_node is None:
            return False

        reduce_l2_pred = next(iter(self.graph.predecessors(reduce_l2_node)))
        axes_attr = reduce_l2_node.get_attribute_by_name("axes")
        if axes_attr:
            axes = axes_attr[0].ints
        else:
            axes = reduce_l2_node.get_initializer_or_constant_value(MIN_INPUT_ORDER)

        reduced_dim = reduce_l2_pred.output_format[axes[0]]
        return reduced_dim == Dims.CHANNELS

    def is_softmax(self):
        """
        Tell whether this vertex starts a softmax pattern.

        Three entry points are recognised:
          * Transpose -> Softmax -> Transpose, with softmax axis 3 and one of the supported perm pairs.
          * Exp that feeds both a ReduceSum -> Div chain and a Div directly.
          * ReduceSum that is fed by an Exp and feeds a Div, where the Exp also feeds a Div directly.

        For the Exp / ReduceSum forms the first reduction axis must be 1. Any other op is not a softmax.

        Returns:
            bool: True when the pattern matches.
        """
        if self.op == "Transpose":
            perm1 = self.get_transpose_perm()
            softmax_chain = get_all_nodes_in_chain(
                self._graph,
                self,
                [FwdChainNode(op="Softmax"), FwdChainNode(op="Transpose")],
            )
            softmax_axis = softmax_chain[0].get_axis() if softmax_chain else 0
            perm2 = softmax_chain[1].get_attribute_by_name("perm")[0].ints if softmax_chain else []
            cond = (
                softmax_chain
                and softmax_axis == 3
                and (
                    (perm1 == [0, 2, 3, 1] and perm2 == [0, 3, 1, 2])
                    or (perm1 == [0, 3, 2, 1] and perm2 == [0, 3, 2, 1])
                )
            )
            # NOTE(review): cond already requires softmax_axis == 3, so this raise cannot fire as written.
            # Kept as-is to avoid changing which models are accepted - confirm the intended axis policy.
            if cond and softmax_axis not in [-1, 3]:
                raise UnsupportedLogitsLayerError(
                    f"Softmax layer near vertex {self.name}, has unsupported axis {softmax_axis} (should be -1 or 3)",
                )
            return bool(cond)

        if self.op == "Exp":
            exp_node = self
            reduce_sum_node = look_for_node(self._graph, self, [FwdChainNode(op="ReduceSum")])
            div_node = look_for_node(self._graph, self, [FwdChainNode(op="ReduceSum"), FwdChainNode(op="Div")])
            div_node_direct = look_for_node(self._graph, self, [FwdChainNode(op="Div")])
        elif self.op == "ReduceSum":
            reduce_sum_node = self
            exp_node = look_for_node(self._graph, self, [BwdChainNode(op="Exp")])
            div_node = look_for_node(self._graph, self, [FwdChainNode(op="Div")])
            # The Exp must also feed the Div directly (same requirement as in the Exp branch).
            div_node_direct = look_for_node(self._graph, self, [BwdChainNode(op="Exp"), FwdChainNode(op="Div")])
        else:
            # No other op can start a softmax pattern.
            return False

        if not exp_node or not reduce_sum_node or not div_node or not div_node_direct:
            return False

        axes = reduce_sum_node.get_axes_information()
        return bool(axes) and axes[0] == 1

    def is_shuffle(self):
        """
        Tell whether this Reshape starts a shuffle (Reshape -> Transpose -> Reshape) pattern.

        When a non-null second Reshape follows the Transpose, the pattern is rejected if a further
        Transpose/Reshape restores this vertex's input shapes; otherwise it matches for a rank-5 first
        reshape with perm [0, 2, 1, 3, 4] or a rank-6 first reshape with one of three known perms.
        When there is no such second Reshape (or it is a null operation), only a 2D -> 2D reshape whose
        dim ratios agree is accepted.
        """
        if self.op != "Reshape":
            return False

        second_reshape = look_for_node(self._graph, self, [FwdChainNode(op="Transpose"), FwdChainNode(op="Reshape")])
        if second_reshape is None or second_reshape.is_null_operation():
            # edge case of qwen2_vl_vision, WC: [4x, y] -> [x, 4y]
            in_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
            out_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
            return bool(
                len(in_shape) == len(out_shape) == 2 and in_shape[0] / out_shape[0] == out_shape[1] / in_shape[1]
            )

        follow_up_chains = [[FwdChainNode(op="Transpose")], [FwdChainNode(op="Reshape")]]
        follow_up = get_node_from_possible_chains(self._graph, second_reshape, follow_up_chains)
        if follow_up and follow_up.get_output_shapes() == self.get_input_shapes():
            # the following op brings the tensor back to its original shape
            return False

        first_reshape, _, perm, _ = self.get_shuffle_reshape_transpose_info()
        rank = len(first_reshape)
        if rank == 5:
            return bool(perm == [0, 2, 1, 3, 4])
        if rank == 6:
            return bool(perm in [[0, 3, 4, 1, 5, 2], [0, 1, 4, 2, 5, 3], [0, 1, 3, 5, 2, 4]])
        return False

    def is_channel_shuffle_null_ops(self):
        """
        Tell whether this Reshape/Transpose belongs to a channel-shuffle block.

        A Reshape matches when it is itself the channel-shuffle reshape, or when the Reshape found behind
        its preceding Transpose is. A Transpose matches when its preceding Reshape is.
        """
        if self.op == "Reshape":
            if self.is_channel_shuffle_reshape():
                return True
            bwd_chain = [BwdChainNode(op="Transpose"), BwdChainNode(op="Reshape")]
        elif self.op == "Transpose":
            bwd_chain = [BwdChainNode(op="Reshape")]
        else:
            return False

        candidate = look_for_node(self._graph, self, bwd_chain)
        return bool(candidate and candidate.is_channel_shuffle_reshape())

    def is_channel_shuffle_reshape(self):
        """
        Tell whether this Reshape opens a channel-shuffle block.

        All of the following must hold:
          * the reshape takes a rank-4 input to a rank-3 output, splitting dim 1 into the first two
            output dims and merging dims 2 and 3 into the last one;
          * it is followed by Transpose (perm [1, 0, 2]) -> Reshape;
          * that second Reshape has identical rank-5 output shapes whose first three dims multiply to
            the original dim 1;
          * the second Reshape has exactly two Gather successors, one holding [0] and one holding [1].
        """
        if self.op != "Reshape":
            return False

        in_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        out_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
        splits_channels_and_flattens = (
            len(in_shape) == 4
            and len(out_shape) == 3
            and out_shape[0] * out_shape[1] == in_shape[1]
            and out_shape[2] == in_shape[2] * in_shape[3]
        )
        if not splits_channels_and_flattens:
            return False

        transpose_reshape = [[FwdChainNode(op="Transpose"), FwdChainNode(op="Reshape")]]
        chain_nodes = get_all_nodes_from_possible_chains(self.graph, self, transpose_reshape)
        if not chain_nodes:
            return False

        if not (chain_nodes[0].get_transpose_perm() == [1, 0, 2]):
            return False

        second_reshape = chain_nodes[1]
        second_shapes = second_reshape.get_output_shapes(convert_to_nhwc=False)
        second_shapes_ok = (
            all(shape == second_shapes[0] for shape in second_shapes)
            and len(second_shapes[0]) == 5
            and np.prod(second_shapes[0][0:3]) == in_shape[1]
        )
        if not second_shapes_ok:
            return False

        gather_succs = [succ for succ in self.graph.successors(second_reshape) if succ.op == "Gather"]
        gather_consts = [
            list(gather.graph.values_by_vertex_name.get(gather.name, {}).values()) for gather in gather_succs
        ]
        return (
            len(gather_succs) == 2
            and any(consts == [0] for consts in gather_consts)
            and any(consts == [1] for consts in gather_consts)
        )

    def is_channel_shuffle_gather_slice(self):
        """
        Tell whether this Gather slices the output of a channel-shuffle block.

        The Gather must be preceded by Reshape <- Transpose <- Reshape, and the farthest Reshape must be
        a channel-shuffle reshape.
        """
        if self.op != "Gather":
            return False

        block_chain = [BwdChainNode(op="Reshape"), BwdChainNode(op="Transpose"), BwdChainNode(op="Reshape")]
        first_reshape = look_for_node(self.graph, self, block_chain)
        if first_reshape is None:
            return False
        return first_reshape.is_channel_shuffle_reshape()

    def is_ew_mult(self):
        """Tell whether this vertex is a multiplication op (MUL_OPS) used element-wise."""
        return self.op in MUL_OPS and self.is_ew_op()

    def is_ew_add(self):
        """Tell whether this vertex is an addition op (ADD_OPS) used element-wise."""
        return self.op in ADD_OPS and self.is_ew_op()

    def is_ew_sub(self):
        """Tell whether this vertex is a Sub op used element-wise."""
        return self.op == "Sub" and self.is_ew_op()

    def is_ew_div(self):
        """Tell whether this vertex is a division op (DIV_OPS) used element-wise."""
        return self.op in DIV_OPS and self.is_ew_op()

    def is_ew_max(self):
        """Tell whether this vertex is a maximum op (MAX_OPS) used element-wise."""
        return self.op in MAX_OPS and self.is_ew_op()

    def is_ew_min(self):
        """Tell whether this vertex is a minimum op (MIN_OPS) used element-wise."""
        return self.op in MIN_OPS and self.is_ew_op()

    def is_ew_op(self):
        """
        Tell whether this vertex combines two tensors element-wise.

        True when exactly two predecessors are non-constant; otherwise the decision is delegated to
        is_ew_op_with_const_input().
        """
        dynamic_pred_count = sum(1 for pred in self.graph.predecessors(self) if not pred.is_const())
        if dynamic_pred_count == 2:
            return True
        return self.is_ew_op_with_const_input()

    def is_mul_by_2_ew_add(self):
        """Tell whether this is an Add whose two inputs are the same tensor (i.e. x + x)."""
        if self.op != "Add":
            return False
        return len(self.input) == 2 and self.input[0] == self.input[1]

    def is_const_reshape(self):
        """
        Tell whether this Reshape reshapes a constant.

        Returns:
            tuple: (True, const_vertex) when the vertex feeding the Reshape's data input ("X") is a
            constant op, (False, None) otherwise - including when no vertex produces that input.
        """
        if self.op != "Reshape":
            return False, None

        reshape_data_input = self._info.input[RESHAPE_INPUT_ORDER.index("X")]
        vertices_by_inp_key = self.graph.vertices_by_inp_key
        # The data input is not necessarily produced by a vertex; treat a missing producer as
        # "not a const reshape" rather than failing the lookup.
        if reshape_data_input not in vertices_by_inp_key:
            return False, None

        pred = vertices_by_inp_key[reshape_data_input]
        if pred.op in CONST_OPS:
            return True, pred

        return False, None

    def is_const_flatten(self):
        """
        Tell whether this Flatten flattens a constant.

        Returns (True, const_vertex) when the first predecessor is a constant op, (False, None) otherwise.
        """
        if self.op != "Flatten":
            return False, None

        preds = list(self.graph.predecessors(self))
        if not preds:
            return False, None

        producer = preds[0]
        return (True, producer) if producer.op in CONST_OPS else (False, None)

    def is_const_unsqueeze(self):
        """
        Tell whether this Unsqueeze unsqueezes a constant.

        Returns:
            (True, const_vertex) when the first predecessor is a constant op;
            (True, []) when the data input (index 0) has a stored value for this vertex;
            (False, None) otherwise.
        """
        if self.op != "Unsqueeze":
            return False, None

        preds = list(self.graph.predecessors(self))
        if preds and preds[0].op in CONST_OPS:
            return True, preds[0]

        # check if values exist in data input index (0)
        vertex_values = self._graph.values_by_vertex_name.get(self.name, {})
        data_value = vertex_values.get(self._info.input[0])
        if data_value is not None:
            return True, []

        return False, None

    def is_const(self):
        """
        Tell whether this vertex produces a constant.

        A vertex is constant when it is a constant op, when (for non-"input" ops) any of its outputs has
        a "<output>_value" entry in the graph's output shapes, or when it is a Reshape / Flatten /
        Unsqueeze of a constant.
        """
        if self.op in CONST_OPS:
            return True

        if self.op != "input":
            known_shapes = self._graph.output_shapes
            if any(output + "_value" in known_shapes for output in self._info.output):
                return True

        # ops that merely re-shape a constant are constants themselves
        wrapper_checks = {
            "Reshape": "is_const_reshape",
            "Flatten": "is_const_flatten",
            "Unsqueeze": "is_const_unsqueeze",
        }
        check_name = wrapper_checks.get(self.op)
        if check_name is None:
            return False
        return getattr(self, check_name)()[0]

    def is_multiply_by_const(self):
        """Tell whether this multiplication has exactly two predecessors, one constant and one not."""
        if self.op not in MUL_OPS:
            return False

        preds = list(self.graph.predecessors(self))
        if len(preds) != 2:
            return False

        const_flags = [pred.op in CONST_OPS for pred in preds]
        return any(const_flags) and not all(const_flags)

    def get_normalization_input_raw_values(self):
        """
        Locate the constant operand of this vertex and return its raw values.

        Only ops in EW_OPS + EQUAL_OPS + ["Greater", "Less"] are searched; inputs 0 and 1 are tried in
        that order and the first one that resolves to a value wins.

        Returns:
            tuple: (values, consumed_vertices, input_index), where consumed_vertices are the vertices
            that were traversed to reach the value. (None, [], -1) when no constant operand is found.
        """
        consumed_vertices = []
        possible_input_indices = []
        if self.op in EW_OPS + EQUAL_OPS + ["Greater", "Less"]:
            possible_input_indices = [0, 1]
        vertex_values = self._graph.values_by_vertex_name.get(self.name, {})

        for index in possible_input_indices:
            value_input = self._info.input[index]

            # Try to get the values from variable initializer inputs
            if value_input in vertex_values:
                return vertex_values[value_input], [], index

            # Try to get the values from constant inputs
            if value_input in self._graph.vertices_by_inp_key:
                input_vertex = self._graph.vertices_by_inp_key[value_input]

                # Shape -> Expand pattern: the value is taken from the Expand vertex's stored values
                is_shape_expand, shape_expand_nodes = self.is_shape_expand_norm()
                if is_shape_expand:
                    _, expand_node = shape_expand_nodes
                    vertex_params = self._graph.values_by_vertex_name.get(expand_node.name, None)
                    if vertex_params:
                        consumed_vertices.extend(shape_expand_nodes)
                        return next(iter(vertex_params.values())), consumed_vertices, index

                # Look through a Cast to the vertex that feeds it (the Cast is marked as consumed
                # even when its producer is not found)
                if input_vertex.op == "Cast":
                    cast = input_vertex
                    consumed_vertices.append(cast)
                    case_input = cast.input[0]
                    if case_input in self._graph.vertices_by_inp_key:
                        input_vertex = self._graph.vertices_by_inp_key[case_input]

                if input_vertex.is_const():
                    consumed_vertices.append(input_vertex)
                    # A pre-computed "<input>_value" entry takes precedence over parsing the constant
                    value_input_key = value_input + "_value"
                    if value_input_key in self._graph.output_shapes:
                        return self._graph.output_shapes[value_input_key], consumed_vertices, index
                    # Unwrap Reshape / Flatten / Unsqueeze wrappers, in this order, down to the constant
                    is_reshape, const = input_vertex.is_const_reshape()
                    if is_reshape:
                        input_vertex = const
                        consumed_vertices.append(input_vertex)
                    is_flatten, const = input_vertex.is_const_flatten()
                    if is_flatten:
                        input_vertex = const
                        consumed_vertices.append(input_vertex)
                    # NOTE(review): is_const_unsqueeze can return (True, []) when the value comes from a
                    # stored input value; input_vertex would then be a list and parse_raw_data() below
                    # would fail - TODO confirm this path is unreachable here.
                    is_unsqueeze, const = input_vertex.is_const_unsqueeze()
                    if is_unsqueeze:
                        input_vertex = const
                        consumed_vertices.append(input_vertex)

                    return input_vertex.parse_raw_data(), consumed_vertices, index

        return None, [], -1

    def _get_normalization_input_values(self):
        """Return the constant operand's values as a flat Python list, or None when there is none."""
        raw_values = self.get_normalization_input_raw_values()[0]
        if raw_values is None:
            return None

        flat_values = raw_values.flatten().tolist()
        if isinstance(flat_values, list):
            return flat_values
        return [flat_values]

    def is_normalization(self):
        """
        Tell whether this vertex can be represented as (part of) a normalization layer.

        Returns:
            bool: True for a shape-expand norm, for negation ops, for Add/Sub with a non-empty all-float
            constant operand, and for Mul/Div with a constant operand (except a Div that is part of a
            GELU activation). Element-wise ops between two tensors are never a normalization.
        """
        # Possible cases (always assume start from the first encountered operator):
        # 1. Add/Sub -> Mul/Div: normalization of the form (x-Mean(x))/(Std(x))
        # 2. Mul/Div -> Add/Sub: equivalent normalization of the form (x-Mean(x)*Std(x))/(Std(x))
        # 3. Add/Sub: normalization with Std(x)=1 (private case Neg has Std(x)=-1)
        # 4. Mul/Div: normalization with Mean(x)=0
        # 5. Shape -> Expand -> Mul/Add
        # Note that for this predicate, 3+4 actually covers 1+2 respectively
        if self.is_shape_expand_norm()[0]:
            return True

        if self.op in NEG_OPS:
            return True

        if self.is_ew_op():
            return False

        if self.op not in ["Add", "Sub", "Mul", "Div"]:
            return False

        # Resolve the constant operand once; every remaining case needs it.
        values = self._get_normalization_input_values()
        if values is None:
            return False

        if self.op in ["Add", "Sub"]:
            return bool(len(values) > 0 and all(isinstance(v, float) for v in values))

        # Mul / Div: a Div that belongs to a GELU activation is handled as an activation instead
        return not (self.op == "Div" and self.is_gelu_activation()[0])

    def _get_split_sizes(self):
        """
        Return the explicit split sizes of this vertex, or None when they are not specified.

        The sizes are looked up, in order, in the "split" attribute, in the vertex's initializer values
        (keyed by the name of its "split" input), and in the vertex that produces the "split" input.
        """
        split_attr = self.get_attribute_by_name("split")
        if split_attr:
            return split_attr[0].ints

        if len(self._info.input) <= 1:
            return None

        splits_key = self._info.input[SPLIT_INPUT_ORDER.index("split")]

        # extracts values from initializer (the vertex may have no initializers at all, or none for
        # the split input, in which case the const input is tried next)
        var_initializers = self._graph.values_by_vertex_name.get(self.name)
        if var_initializers and splits_key in var_initializers:
            return [int(x) for x in var_initializers[splits_key].tolist()]

        # extracts values from const input
        if splits_key in self._graph.vertices_by_inp_key:
            const = self._graph.vertices_by_inp_key[splits_key]
            if const:
                return const.parse_raw_data().tolist()

        return None

    def is_spatial_splitter(self):
        """
        Tell whether this split op divides its input along the height or width dim.

        Requires explicit split sizes and a known output format.
        """
        if self.op not in SPLIT_OPS:
            return False
        if not self._get_split_sizes():
            return False

        split_axis = self.get_axis()
        out_format = self.output_format
        if out_format is None:
            return False
        return out_format[split_axis] in [Dims.HEIGHT, Dims.WIDTH]

    def get_normalization_activation(self, consumed_vertices):
        """
        Pick the activation of the normalization formed by this vertex and its consumed vertices.

        Returns ActivationType.inv_pos as soon as a division op whose constant is its first input,
        i.e. Div(c, x), is found; ActivationType.linear otherwise.
        """
        candidates = [self, *consumed_vertices]
        for candidate in candidates:
            if candidate.op not in DIV_OPS:
                continue
            const_input_idx = candidate.get_normalization_input_raw_values()[2]
            if const_input_idx == 0:  # Div(c, x)
                return ActivationType.inv_pos
        return ActivationType.linear

    def get_normalization_info(self):
        """
        Extract the mean/std of the normalization that starts at this vertex.

        Returns:
            tuple: (mean, std, consumed_vertices) where consumed_vertices holds the following Mul/Div
            (for Add/Sub starts) or Add/Sub (for Mul/Div starts) vertex when it was folded in.

        Raises:
            UnsupportedNormalizationLayerError: when the op is not supported or its constant operand
                can't be found.
        """
        # Possible cases (always assume start from the first encountered operator):
        # 1. Add/Sub -> Mul/Div: normalization of the form (x-Mean(x))/(Std(x))
        # 2. Mul/Div -> Add/Sub: equivalent normalization of the form (x-Mean(x)*Std(x))/(Std(x))
        # 3. Add/Sub: normalization with Std(x)=1 (private case Neg has Std(x)=-1)
        # 4. Mul/Div: normalization with Mean(x)=0
        # 5. Shape -> Expand -> Mul/Add
        if self.op in ["Add", "Sub"]:
            # covering case #3
            raw_mean, _, mean_input_idx = self.get_normalization_input_raw_values()
            if raw_mean is None:
                raise UnsupportedNormalizationLayerError(
                    f"Could not find mean values in normalization starting from node {self.name}",
                )
            # either add -mean or subtract mean
            mean = (np.negative(raw_mean) if self.op == "Add" else raw_mean).flatten().tolist()
            mean = mean if isinstance(mean, list) else [mean]

            # covering case #1
            std = [1.0]
            std_node = get_node_from_possible_chains(self._graph, self, [[FwdChainNode("Mul")], [FwdChainNode("Div")]])
            if std_node is not None and std_node.is_normalization():
                raw_std, _, _ = std_node.get_normalization_input_raw_values()
                # integer constants must be cast first: an integer reciprocal truncates to 0
                raw_std = np.asarray(raw_std)
                if np.issubdtype(raw_std.dtype, np.integer):
                    raw_std = raw_std.astype(np.float32)
                # either multiply by 1/std or divide by std
                std = (np.reciprocal(raw_std) if std_node.op in MUL_OPS else raw_std).flatten().tolist()
                std = std if isinstance(std, list) else [std]
            else:
                std_node = None

            if self.op == "Sub" and mean_input_idx == 0:  # Sub(c, x)
                std = [-1 * x for x in std]
            return mean, std, [std_node] if std_node is not None else []

        if self.op in NEG_OPS:
            return [0.0], [-1.0], []  # can also be seen as case #3

        if self.op in ["Mul", "Div"]:
            # covering case #4
            raw_std, _, std_input_idx = self.get_normalization_input_raw_values()
            if raw_std is None:
                raise UnsupportedNormalizationLayerError(
                    f"Could not find std values in normalization starting from node {self.name}",
                )

            # either multiply by 1/std or divide by std (except for a mul-by-const scenario)
            raw_std = raw_std.astype(np.float32) if np.issubdtype(raw_std.dtype, np.integer) else raw_std
            if self.op in MUL_OPS:
                # a multiplication by 0 is represented as std = inf
                raw_std = np.reciprocal(raw_std, where=raw_std != 0).flatten()
                raw_std[raw_std == 0] = np.inf
            std = raw_std.flatten().tolist()
            if self.output_format and Dims.GROUPS in self.output_format:
                out_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
                # one std value per group: repeat each by the size of the channels dim
                if len(std) == out_shape[self.output_format.index(Dims.GROUPS)]:
                    std = np.repeat(std, out_shape[self.output_format.index(Dims.CHANNELS)])

            # covering case #2: multiply mean by std
            mean_node = None
            mean = [0.0]

            # Edge cases where case #2 can't be represented by normalization:
            # 1. mul by 0, the std will be np.inf
            # 2. Div(c, x)
            if np.isfinite(std).all() and (self.op == "Mul" or std_input_idx == 1):
                chains = [[FwdChainNode("Sub")], [FwdChainNode("Add")]]
                mean_node = get_node_from_possible_chains(self._graph, self, chains)

                if mean_node is not None and mean_node.is_normalization():
                    raw_mean, _, mean_input_idx = mean_node.get_normalization_input_raw_values()
                    # either add -mean or subtract mean
                    raw_mean = (np.negative(raw_mean) if mean_node.op in ADD_OPS else raw_mean).flatten().tolist()
                    raw_mean = raw_mean if isinstance(raw_mean, list) else [raw_mean]
                    extended_std = std * len(raw_mean) if len(std) == 1 else std
                    mean = [x * y for x, y in zip(raw_mean, extended_std)]
                    if mean_node.op == "Sub" and mean_input_idx == 0:  # Sub(c, x) after Div(x, d)
                        std = [-1 * x for x in std]
                else:
                    mean_node = None

            return mean, std, [mean_node] if mean_node is not None else []

        raise UnsupportedNormalizationLayerError(
            f"Could not find std/mean values in normalization starting from node {self.name}",
        )

    def get_mult_scalar(self):
        """
        Return the constant factor of a standalone mult layer.

        Returns:
            tuple: (flat list of the first constant predecessor's values, [that predecessor]).

        Raises:
            UnsupportedMultLayerError: when no predecessor is a constant op.
        """
        scalar_node = None
        for pred in self._graph.predecessors(self):
            if pred.op in CONST_OPS:
                scalar_node = pred
                break

        if not scalar_node:
            raise UnsupportedMultLayerError(f"Could not find scalar value in standalone mult layer {self.name}")

        raw_scalar = scalar_node.parse_raw_data()
        return raw_scalar.flatten().tolist(), [scalar_node]

    def get_softmax_nodes(self):
        """
        Return the vertices of the softmax pattern that this vertex belongs to.

        Returns:
            list: [Exp, ReduceSum, Div] when called on a ReduceSum (entries may be None when a
            neighbour is not found); this vertex followed by the ReduceSum -> Div chain when called on
            an Exp; this vertex followed by the Softmax -> Transpose chain when called on a Transpose.

        Raises:
            UnexpectedNodeError: when called on any other op.
        """
        if self.op == "ReduceSum":
            return [
                look_for_node(self._graph, self, [BwdChainNode(op="Exp")]),
                self,
                look_for_node(self._graph, self, [FwdChainNode(op="Div")]),
            ]

        if self.op == "Exp":
            chain = [FwdChainNode(op="ReduceSum"), FwdChainNode(op="Div")]
        elif self.op == "Transpose":
            chain = [FwdChainNode(op="Softmax"), FwdChainNode(op="Transpose")]
        else:
            raise UnexpectedNodeError(
                f"Vertex {self.name} of type {self.op} is not a supported start of a softmax pattern",
            )

        # the chain lookup may come back empty when the pattern is not present
        chain_nodes = get_all_nodes_in_chain(self._graph, self, chain) or []
        return [self, *chain_nodes]

    def get_channel_shuffle_slice_info(self):
        """
        Build the slice parameters for a Gather that picks one branch of a channel-shuffle block.

        Returns:
            tuple: (h_slice, w_slice, f_slice, groups, consumed_vertices) where the height/width slices
            span the last two dims of this vertex's input, the features slice is the single gathered
            index, and groups is dim 1 of the block's input divided by dim 0 of this vertex's input.
        """
        block_chain = [BwdChainNode(op="Reshape"), BwdChainNode(op="Transpose"), BwdChainNode(op="Reshape")]
        shuffle_reshape = look_for_node(self.graph, self, block_chain)

        gather_in_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        block_in_shape = shuffle_reshape.get_input_shapes(convert_to_nhwc=False)[0]
        groups = int(block_in_shape[1] // gather_in_shape[0])

        gather_values = list(self.graph.values_by_vertex_name.get(self.name, {}).values())
        gather_index = int(gather_values[0])

        height_slice = [0, gather_in_shape[-2]]
        width_slice = [0, gather_in_shape[-1]]
        features_slice = [gather_index, gather_index + 1]
        consumed_vertices = []  # empty list for consumed vertices

        return height_slice, width_slice, features_slice, groups, consumed_vertices

    def get_feature_split_info(self):
        """
        Collect the parameters of a feature splitter.

        Returns:
            tuple: (split_sizes, hailo_output_shapes, groups). split_sizes holds the last dim of every
            converted output shape; groups is the size of the groups dim when splitting over the
            features axis of a grouped format, and 1 otherwise.

        Raises:
            UnsupportedFeatureSplitterLayerError: when the split axis is neither the features axis, a
                groups/stack axis, nor axis 0 of a rank-5 (kqv) input.
        """
        axis = self.get_axis()
        onnx_input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        onnx_output_shapes = self.get_output_shapes(convert_to_nhwc=False)

        is_kqv_split = axis == 0 and onnx_input_shapes and len(onnx_input_shapes[0]) == 5
        is_features_axis = (self.output_format is None and axis == 1) or (
            self.output_format and self.output_format[axis] == Dims.CHANNELS
        )
        is_groups_axis = self.output_format and self.output_format[axis] in [Dims.GROUPS, Dims.STACK]
        supported_axis = is_features_axis or is_kqv_split or is_groups_axis
        if not supported_axis:
            raise UnsupportedFeatureSplitterLayerError(
                f"Feature splitter vertex {self.name} is splitting input over unsupported axis {axis}",
            )

        hailo_output_shapes = self.get_output_shapes(validate_zero_dims=True)
        if is_kqv_split:
            hailo_output_shapes = [self.convert_kqv_split_shape_to_nhwc(shape) for shape in onnx_output_shapes]
            split_sizes = []
            for output in self._info.output:
                converted = self.convert_kqv_split_shape_to_nhwc(self._graph.output_shapes[output][0])
                split_sizes.append(converted[-1])
        else:
            split_sizes = []
            for output in self._info.output:
                converted = self.convert_nchw_to_nhwc(self._graph.output_shapes[output][0], self.output_format)
                split_sizes.append(converted[-1])

        groups = 1
        if is_features_axis and self.output_format and Dims.GROUPS in self.output_format:
            groups = onnx_output_shapes[0][self.output_format.index(Dims.GROUPS)]

        return split_sizes, hailo_output_shapes, groups

    def get_spatial_split_info(self):
        """
        Collect the parameters of a spatial splitter.

        Returns (axis, split_sizes, output_shapes) where axis is 1 when the split dim is HEIGHT and 2
        otherwise.
        """
        split_dim = self.output_format[self.get_axis()]
        if split_dim == Dims.HEIGHT:
            axis = 1
        else:
            axis = 2
        return axis, self._get_split_sizes(), self.get_output_shapes_by_info()

    def convert_kqv_split_shape_to_nhwc(self, shape):
        """
        Convert from shape format [batch, kqv(3), heads, width, channels] to nhwc.

        The batch and kqv dims are merged, height is set to 1, and heads and channels are merged into
        the features dim.
        """
        batch, kqv = shape[0], shape[1]
        heads, width, channels = shape[2], shape[3], shape[4]
        return [batch * kqv, 1, width, heads * channels]

    def get_axis(self):
        """Return the integer value of the "axis" attribute, or None when the attribute is absent."""
        axis_attr = self.get_attribute_by_name("axis")
        if not axis_attr:
            return None
        return axis_attr[0].i

    def get_keepdims(self):
        """Return the "keepdims" attribute as a bool; defaults to True when the attribute is absent."""
        keepdims_attr = self.get_attribute_by_name("keepdims")
        return bool(keepdims_attr[0].i) if keepdims_attr else True

    def get_vertex_successors_io_indices(self):
        """
        Group this vertex's successors by the output they consume.

        Successor inputs that contain "<vertex name><separator>" are matched to this vertex. For
        multiplication ops and Add only the first matching input is considered. The groups are ordered
        by the position of the consumed output in self.output.

        Returns:
            dict: {running index: [successors consuming that output]}.
        """
        name_token = f"{self.name}{VERTEX_NAME_SEPARATOR}"
        succs_by_inp_idx = {}
        for succ in self._graph.successors(self):
            input_indices = [x for x in succ.input if name_token in x]
            if succ.op in [*MUL_OPS, "Add"]:
                # keep only the first match; slicing tolerates a successor with no matching input
                input_indices = input_indices[:1]

            for input_index in input_indices:
                succs_by_inp_idx.setdefault(input_index, []).append(succ)

        outputs = list(self.output)

        def sort_func(x):
            # splits x by ':' that stands alone and not part of '::' - onnx::resize:25 -> [onnx::resize, 25]
            split_x = re.split(r"(?<!:):(?!:)", x)
            return outputs.index(split_x[0]) if len(split_x) == 1 else outputs.index(split_x[1])

        return {i: succs_by_inp_idx[x] for i, x in enumerate(sorted(succs_by_inp_idx, key=sort_func))}

    def get_reshape_shapes(self):
        """
        Return the target shape of this reshape together with the vertices that only compute it.

        The shape is taken, in order, from: the inferred output shapes; a Cast / Constant / Concat
        predecessor; the vertex's own stored input values.

        Returns:
            tuple: (shape, consumed_vertices).

        Raises:
            UnsupportedReshapeError: when no shape source is found, or when a Concat that builds the
                shape has a predecessor that is not an Unsqueeze.
        """
        consumed_vertices = []
        output_shapes = self.get_output_shapes(convert_to_nhwc=False)
        if output_shapes:
            # shape inference succeeded: all shape-computing predecessors are consumed
            consumed_vertices = self.get_all_shape_nodes()
            return output_shapes[0], consumed_vertices

        # Fallback from shape inference
        possible_chains = [[BwdChainNode(op="Cast")], [BwdChainNode(op="Constant")], [BwdChainNode(op="Concat")]]
        shape_vertex = get_node_from_possible_chains(self._graph, self, possible_chains)
        if shape_vertex:
            consumed_vertices.append(shape_vertex)
            if shape_vertex.op == "Cast":
                # the shape is the Constant behind the Cast
                shape_vertex = look_and_validate(self._graph, shape_vertex, [BwdChainNode(op="Constant")])
                consumed_vertices.append(shape_vertex)
            elif shape_vertex.op == "Concat":
                concat_preds = list(self._graph.predecessors(shape_vertex))
                consumed_vertices.extend(concat_preds)
                for pred in concat_preds:
                    if pred.op != "Unsqueeze":
                        raise UnsupportedReshapeError(
                            f"Expected Unsqueeze node before concat node in reshape "
                            f"{self.name}, but found node of type {pred.op}",
                        )
                    chain = [BwdChainNode(op="Gather"), BwdChainNode(op="Shape")]
                    consumed_vertices.extend(get_all_nodes_in_chain(self._graph, pred, chain))
                # Concat inputs without a stored value stay -1; the others take the first element of
                # their stored value
                shapes = [-1] * len(shape_vertex._info.input)
                const_values = self._graph.values_by_vertex_name[shape_vertex.name]
                for key, value in const_values.items():
                    shapes[list(shape_vertex._info.input).index(key)] = value[0]
                return shapes, consumed_vertices

            shapes = shape_vertex.parse_raw_data(cast_to_int=True).flatten().tolist()
            return shapes, consumed_vertices

        if self.graph.values_by_vertex_name.get(self.name):
            # covers a more basic implementation where a single reshape op is used
            shapes = self.graph.values_by_vertex_name[self.name][self._info.input[-1]].tolist()
            if len(shapes) == 4:
                # reorder a rank-4 shape from positions [0, 1, 2, 3] to [0, 2, 3, 1]
                # (presumably NCHW -> NHWC - TODO confirm)
                shapes = [shapes[0], shapes[2], shapes[3], shapes[1]]
            return shapes, consumed_vertices

        raise UnsupportedReshapeError(f"Couldn't find reshape shapes value in {self.name}")

    def get_all_shape_nodes(self):
        """
        Collect, depth-first, every upstream vertex that only takes part in computing a shape.

        A predecessor qualifies when it is a non-null shape/logical/math op or a shape concat. The walk
        continues only through qualifying predecessors and each vertex is listed once.
        """
        collected = []
        shape_related_ops = SHAPE_OPS + LOGICAL_OPS + MATH_OPS

        def visit(node):
            for pred in list(node.graph.predecessors(node)):
                is_shape_related = (
                    pred.op in shape_related_ops and not pred.is_null_operation()
                ) or pred.is_shape_concat()
                if not is_shape_related or pred in collected:
                    continue
                collected.append(pred)
                visit(pred)

        visit(self)
        return collected

    def is_shape_concat(self):
        """Tell whether this is a concat op whose predecessors are all shape ops."""
        if self.op not in CONCAT_OPS:
            return False
        return all(pred.op in SHAPE_OPS for pred in self._graph.predecessors(self))

    def get_shuffle_reshape_transpose_info(self):
        """
        Collect the parameters of a Reshape -> Transpose -> Reshape block starting at this vertex.

        Returns:
            tuple: (first reshape shape, second reshape shape, transpose perm, consumed vertices).
        """
        first_reshape, first_consumed = self.get_reshape_shapes()

        transpose = look_and_validate(self._graph, self, [FwdChainNode(op="Transpose")])
        perm = transpose.get_attribute_by_name("perm")[0].ints

        second_reshape_vertex = look_and_validate(self._graph, transpose, [FwdChainNode(op="Reshape")])
        second_reshape, second_consumed = second_reshape_vertex.get_reshape_shapes()

        consumed_vertices = [*first_consumed, transpose, *second_consumed, second_reshape_vertex]
        return first_reshape, second_reshape, perm, consumed_vertices

    def is_f_to_w_transpose_reshape(self):
        """
        Tell whether this Transpose starts a Transpose -> Reshape -> Transpose block whose perms are
        [0, 2, 3, 1] and [0, 3, 1, 2] respectively.
        """
        if self.op != "Transpose":
            return False

        closing_transpose = look_for_node(self._graph, self, [FwdChainNode("Reshape"), FwdChainNode("Transpose")])
        if not closing_transpose:
            return False

        opening_perm = self.get_transpose_perm()
        closing_perm = closing_transpose.get_transpose_perm()
        return opening_perm == [0, 2, 3, 1] and closing_perm == [0, 3, 1, 2]

    def get_f_to_w_transpose_reshape_info(self):
        """
        Return the reshape shape and the consumed vertices of a Transpose -> Reshape -> Transpose block.

        The consumed vertices are the reshape, the closing transpose and whatever get_reshape_shapes of the
        reshape reports as consumed, in that order.
        """
        reshape_node = look_for_node(self._graph, self, [FwdChainNode(op="Reshape")])
        closing_transpose = look_for_node(self._graph, self, [FwdChainNode("Reshape"), FwdChainNode("Transpose")])
        shape, reshape_consumed = reshape_node.get_reshape_shapes()
        return shape, [reshape_node, closing_transpose, *reshape_consumed]

    def is_flat_to_frames_reshape(self):
        """
        Check whether this Reshape (optionally followed by a Transpose) turns a rank-2 tensor into frames.

        When output shapes are available, the second dimension of the rank-2 input must equal the product of the
        last two dimensions of the final rank-3 / rank-4 output (a rank-4 output must also have 1 at index 1).
        When they are not available, only the target shape read through get_reshape_shapes is inspected.
        A following Transpose is accepted only with perm [0, 3, 1, 2] (rank 4) or [0, 2, 1] (rank 3).
        """
        if self.op != "Reshape":
            return False

        inferred_output_shapes = self.get_output_shapes()
        shapes_are_inferred = bool(inferred_output_shapes)
        if shapes_are_inferred:
            target_shape = inferred_output_shapes[0]
        else:
            target_shape, _ = self.get_reshape_shapes()

        transpose_node = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])
        if transpose_node:
            perm = transpose_node.get_attribute_by_name("perm")[0].ints
            # only rank-3 and rank-4 perms are restricted, any other rank is let through
            expected_perm = {4: [0, 3, 1, 2], 3: [0, 2, 1]}.get(len(perm))
            if expected_perm is not None and perm != expected_perm:
                return False

        if not shapes_are_inferred:
            target_rank = len(target_shape)
            return (target_rank == 4 and target_shape[1] == 1) or target_rank == 3

        if transpose_node:
            # with a trailing transpose the block output is the transpose output
            target_shape = transpose_node.get_output_shapes()[0]

        flat_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        if len(flat_shape) != 2:
            return False

        if len(target_shape) == 3:
            return target_shape[1] * target_shape[2] == flat_shape[1]
        if len(target_shape) == 4:
            return (target_shape[2] * target_shape[3] == flat_shape[1]) and target_shape[1] == 1
        return False

    def get_flat_to_frames_reshape_info(self):
        """
        Return the rank-4 target shape of a flat-to-frames reshape and the vertices it consumes.

        With a following Transpose, a rank-3 shape only gets a 1 inserted at index 1 and the transpose is consumed.
        Without it, a rank-3 shape gets the 1 inserted and its last two entries swapped, and a rank-4 shape has its
        entry at index 1 moved to the end.
        """
        shape, reshape_consumed = self.get_reshape_shapes()
        transpose_node = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])
        rank = len(shape)

        if transpose_node:
            if rank == 3:
                shape = [shape[0], 1, shape[1], shape[2]]
        elif rank == 3:
            shape = [shape[0], 1, shape[2], shape[1]]
        elif rank == 4:
            shape = [shape[0], shape[2], shape[3], shape[1]]

        consumed_vertices = [transpose_node] if transpose_node else []
        consumed_vertices.extend(reshape_consumed)
        return shape, consumed_vertices

    def is_successive_unsqueeze_flat_to_frame(self):
        """
        Check whether this Unsqueeze (optionally followed by a second Unsqueeze) lifts a rank-2 input to a
        rank-3 or rank-4 output.

        Only the ranks are compared here; the values of the added dimensions are not inspected.
        """
        if self.op != "Unsqueeze":
            return False

        following_unsqueeze = look_for_node(self._graph, self, [FwdChainNode(op="Unsqueeze")])
        last_unsqueeze = self if following_unsqueeze is None else following_unsqueeze

        block_output_shape = last_unsqueeze.get_output_shapes(convert_to_nhwc=False)[0]
        block_input_shape = self.get_input_shapes()[UNSQUEEZE_INPUT_ORDER.index("X")]

        if len(block_input_shape) != 2:
            return False
        return len(block_output_shape) in [3, 4]

    def is_supported_one_hot(self):
        """
        Check whether this vertex is a one-hot op in the only supported pattern: OneHot -> Transpose -> Squeeze.

        The pattern requires a single predecessor, a one-hot axis of -1, a transpose perm of [0, 4, 2, 3, 1] and a
        squeeze axis of 4 or -1.

        Returns:
            bool: True for the supported pattern. False otherwise, including when no constant values are recorded
            for the squeeze vertex, so its axis cannot be verified.
        """
        if self.op not in ONE_HOT_OPS:
            return False

        preds = list(self._graph.predecessors(self))
        one_hot_axis = self.get_axis()
        chain = get_all_nodes_in_chain(self._graph, self, [FwdChainNode("Transpose"), FwdChainNode("Squeeze")])
        if len(preds) != 1 or not chain:
            return False

        perm = chain[0].get_transpose_perm()
        squeeze_values = self._graph.values_by_vertex_name.get(chain[1].name)
        if not squeeze_values:
            # nothing recorded for the squeeze vertex: previously this dereferenced None and crashed
            return False

        squeeze_axis = next(iter(squeeze_values.values())).tolist()[0]
        return bool(one_hot_axis == -1 and perm == [0, 4, 2, 3, 1] and squeeze_axis in [4, -1])

    def get_one_hot_info(self):
        """
        Return the vertices consumed by the one-hot block (the Transpose -> Squeeze chain) and a copy of this
        vertex's output shapes in which the entry at index 1 of the first shape is removed.
        """
        consumed = get_all_nodes_in_chain(self._graph, self, [FwdChainNode("Transpose"), FwdChainNode("Squeeze")])
        # work on a deep copy so the shapes stored on the vertex stay untouched
        out_shapes = deepcopy(self.get_output_shapes())
        del out_shapes[0][1]
        return consumed, out_shapes

    def get_flat_to_frames_successive_unsqueeze_info(self):
        """
        Return the rank-4 shape [d0, 1, 1, d1] built from the first two input dimensions, together with the
        consumed vertices: the following Unsqueeze when there is one, otherwise an empty list.
        """
        following_unsqueeze = look_for_node(self._graph, self, [FwdChainNode(op="Unsqueeze")])
        flat_input_shape = self.get_input_shapes()[UNSQUEEZE_INPUT_ORDER.index("X")]
        frames_shape = [flat_input_shape[0], 1, 1, flat_input_shape[1]]
        consumed_vertices = [following_unsqueeze] if following_unsqueeze else []

        return frames_shape, consumed_vertices

    def is_width_features_transpose(self):
        """
        Check whether this Transpose swaps width and features.

        A spatial-unflatten transpose never qualifies and an RNN-sequence transpose always does. Otherwise the perm
        must be [0, 2, 1] or [0, 3, 2, 1].
        """
        if self.op != "Transpose" or self.is_spatial_unflatten():
            return False

        return True if self.is_rnn_sequence() else self.get_transpose_perm() in ([0, 2, 1], [0, 3, 2, 1])

    def is_height_width_transpose(self):
        """
        Check whether this vertex is a Transpose with perm [0, 1, 3, 2] that is not a transpose after a spatial
        flatten (see is_transpose_after_spatial_flatten).
        """
        if self.is_transpose_after_spatial_flatten():
            return False
        if self.op != "Transpose":
            return False
        return self.get_transpose_perm() == [0, 1, 3, 2]

    def is_transpose_after_spatial_flatten(self):
        """
        Check whether this vertex sits downstream of a spatial-flatten Reshape.

        Walks up from the single predecessor through vertices that have exactly one predecessor and one successor
        until a Reshape is found, and requires that Reshape to be a spatial flatten. Then, from every successor of
        that Reshape, walks down through single-predecessor / single-successor vertices and requires each branch to
        reach a Transpose.
        """
        direct_preds = list(self._graph.predecessors(self))
        if len(direct_preds) != 1:
            return False

        # climb towards the closest Reshape, refusing any branching on the way
        ancestor = direct_preds[0]
        while ancestor.op != "Reshape":
            ancestor_preds = list(self._graph.predecessors(ancestor))
            fan_out = len(list(self._graph.successors(ancestor)))
            if fan_out != 1 or len(ancestor_preds) != 1:
                return False
            ancestor = ancestor_preds[0]

        if not ancestor.is_spatial_flatten_reshape():
            return False

        # every branch leaving the reshape has to end up in a Transpose
        for branch_head in self.graph.successors(ancestor):
            vertex = branch_head
            while vertex.op != "Transpose":
                following = list(self._graph.successors(vertex))
                if len(following) != 1 or len(list(self._graph.predecessors(vertex))) != 1:
                    return False
                vertex = following[0]

        return True

    def is_transpose_1d_already_flattened(self):
        """
        Check whether this Transpose has perm [0, 2, 1], its first predecessor outputs the
        [BATCH, CHANNELS, WIDTH] format, and at least one net input format is [BATCH, CHANNELS, WIDTH] as well.
        """
        if self.op != "Transpose":
            return False

        flattened_format = [Dims.BATCH, Dims.CHANNELS, Dims.WIDTH]
        perm = self.get_transpose_perm()
        producer = next(iter(self.graph.predecessors(self)))

        swaps_last_two_axes = perm == [0, 2, 1]
        producer_is_flattened = producer.output_format == flattened_format
        net_input_formats = self.graph.net_input_format
        # keeps the falsy net_input_format value itself when there are no net input formats
        some_net_input_is_flattened = net_input_formats and any(
            net_format == flattened_format for net_format in net_input_formats.values()
        )
        return swaps_last_two_axes and producer_is_flattened and some_net_input_is_flattened

    def is_transpose_before_spatial_flatten(self):
        """
        Check whether this Transpose has successors and all of them are spatial-flatten reshapes.

        Note that a Transpose without successors yields the empty successors list itself (falsy) rather than False.
        """
        if self.op != "Transpose":
            return False

        consumers = list(self._graph.successors(self))
        return consumers and all(consumer.is_spatial_flatten_reshape() for consumer in consumers)

    def is_transpose_after_features_to_groups(self):
        """
        Check whether this Transpose follows a features-to-groups Reshape.

        The Reshape is either the single direct predecessor, or the single predecessor of a directly preceding
        Unsqueeze. The decision is delegated to is_features_to_groups_reshape, called with the output format of
        the Reshape's first predecessor.
        """
        if self.op != "Transpose":
            return False

        direct_preds = list(self._graph.predecessors(self))
        if len(direct_preds) != 1:
            return False

        producer = direct_preds[0]
        if producer.op == "Reshape":
            reshape_pred = producer
        elif producer.op == "Unsqueeze":
            # possible chain: reshape -> unsqueeze -> transpose
            unsqueeze_preds = list(self._graph.predecessors(producer))
            if len(unsqueeze_preds) != 1 or unsqueeze_preds[0].op != "Reshape":
                return False
            reshape_pred = unsqueeze_preds[0]
        else:
            return False

        reshape_source = next(iter(self._graph.predecessors(reshape_pred)))
        return reshape_pred.is_features_to_groups_reshape(reshape_source.output_format)[0]

    def is_transpose_before_groups_to_features(self):
        """
        Check whether this Transpose has a single successor that is a groups-to-features Reshape, judged against
        this vertex's output format.
        """
        if self.op != "Transpose":
            return False

        consumers = list(self._graph.successors(self))
        if len(consumers) != 1:
            return False

        (only_consumer,) = consumers
        if only_consumer.op != "Reshape":
            return False
        return only_consumer.is_groups_to_features_reshape(self.output_format)[0]

    def is_spatial_flatten_features_to_width(self):
        """
        This function checks if the reshape is of the form:
        [batch, channels, height, width] -> [batch, x, channels // x, height * width]

        The reshape must be followed by a Transpose with perm [0, 2, 3, 1] that has exactly one output shape.
        The comparison uses the last three dimensions of this vertex's input shape and of the transpose output.

        Returns:
            bool: True when the pattern matches. Shapes with fewer than three dimensions, or a zero-sized last
            output dimension, are reported as a non-match instead of raising.
        """
        transpose = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])
        if not transpose:
            return False

        if transpose.get_transpose_perm() != [0, 2, 3, 1]:
            return False

        input_shape = self.get_input_shapes()[RESHAPE_INPUT_ORDER.index("X")]
        output_shapes = transpose.get_output_shapes()
        if len(output_shapes) != 1:
            return False

        output_shape = output_shapes[0]
        if len(input_shape) < 3 or len(output_shape) < 3:
            # the three-way unpacking below needs at least three dimensions on both sides
            return False

        h, w, c = input_shape[-3:]
        out_h, out_w, out_c = output_shape[-3:]
        if out_c == 0:
            return False

        # integer form of `c / out_c == out_w`, avoids float division
        return c == out_c * out_w and out_h == w * h

    def get_spatial_flatten_features_to_width_info(self):
        """Return the output shapes of the following Transpose and that Transpose as the only consumed vertex."""
        following_transpose = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])
        block_output_shapes = following_transpose.get_output_shapes()
        return block_output_shapes, [following_transpose]

    def is_input_to_attention_windows_reshape(self):
        """
        This function checks if the reshape block is of the form:
        [1, h, w, c] -> [(h / window_h) ** 2, window_h ** 2, c]

        The block parameters come from get_input_to_windows_info; without a matching chain the answer is False.
        """
        if self.op != "Reshape":
            return False

        _, chain, params = self.get_input_to_windows_info()
        if not chain:
            return False

        height, width = params["h"], params["w"]
        window_h, window_size = params["window_h"], params["window_size"]
        num_windows = params["num_windows"]
        features_in, features_out = params["start_block_f_out"], params["end_block_f_out"]

        return (
            height == width  # input is symmetric
            and window_size // window_h == window_h  # window is symmetric
            and window_h**2 * num_windows == height * width
            and features_out == features_in  # output channels
        )

    def get_input_to_windows_info(self):
        """
        Collect the chain and the parameters of an input-to-windows reshape block that starts at this vertex.

        Returns:
            tuple of (output_shapes, chain, params). output_shapes are the output shapes of the last chain vertex,
            each prefixed with -1. params holds the keys "width_features", "h", "w", "window_h", "num_windows",
            "window_size", "start_block_f_out" and "end_block_f_out". When no block is recognized the result is
            ([], [], {}).
        """
        params = {"width_features": True}
        # preferred chains: three or four vertices following this reshape
        chains = [
            [FwdChainNode("Transpose"), FwdChainNode("Reshape"), FwdChainNode("Reshape")],
            [FwdChainNode("Transpose"), FwdChainNode("Reshape"), FwdChainNode("Transpose"), FwdChainNode("Reshape")],
        ]
        chain = get_all_nodes_from_possible_chains(self._graph, self, chains)
        in_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        if chain:
            # the input format is needed to locate height / width / channels in the input shape
            if not self.input_format:
                return [], [], {}
            params["h"] = in_shape[self.input_format.index(Dims.HEIGHT)]
            params["w"] = in_shape[self.input_format.index(Dims.WIDTH)]
            params["start_block_f_out"] = in_shape[self.input_format.index(Dims.CHANNELS)]
            # window height is read from this vertex's own output: last dim for the four-vertex chain,
            # second to last otherwise
            params["window_h"] = self.get_output_shapes(convert_to_nhwc=False)[0][-2 if len(chain) != 4 else -1]

        else:
            # fallback: a shorter Transpose -> Reshape chain, accepted only for a six-entry reshape shape
            chain = get_all_nodes_in_chain(self._graph, self, [FwdChainNode("Transpose"), FwdChainNode(op="Reshape")])
            reshape_shape = self.get_reshape_shapes()[0]
            if not chain or len(reshape_shape) != 6:
                return [], [], {}
            params["start_block_f_out"] = in_shape[-1]
            # height and width are the products of reshape entries [1:3] and [3:5]
            params["h"], params["w"] = np.prod(reshape_shape[1:3]), np.prod(reshape_shape[3:5])
            params["window_h"] = reshape_shape[2]

        # the block has to end with a rank-3 output: (num_windows, window_size, features)
        chain_out_shapes = chain[-1].get_output_shapes(convert_to_nhwc=False)
        if len(chain_out_shapes[0]) != 3:
            return [], [], {}
        params["num_windows"], params["window_size"], params["end_block_f_out"] = chain_out_shapes[0]
        if len(chain) == 4:
            # for the four-vertex chain the first two output dims are swapped and width_features is cleared
            params["width_features"] = False
            params["window_size"], params["num_windows"] = params["num_windows"], params["window_size"]

        output_shapes = [[-1, *output_shape] for output_shape in chain_out_shapes]
        return output_shapes, chain, params

    def is_input_to_windows_chain_end(self):
        """
        Check whether this vertex closes an input-to-windows block: walk backwards over the known chains and ask
        the vertex found at the start whether it is an input-to-attention-windows reshape.

        Returns the (falsy) lookup result itself when no chain start is found.
        """
        # The order of the chains is important cause they contain each other and we want the maximal
        backward_op_chains = (
            ("Transpose", "Reshape", "Transpose", "Reshape"),
            ("Reshape", "Transpose", "Reshape"),
            ("Transpose", "Reshape"),
        )
        chains = [[BwdChainNode(op) for op in ops] for ops in backward_op_chains]
        chain_start = get_node_from_possible_chains(self._graph, self, chains)
        return chain_start and chain_start.is_input_to_attention_windows_reshape()

    def is_attention_windows_to_input_reshape(self):
        """
        This function checks if the reshape block is of the form:
        [(h / window_h) ** 2, window_h ** 2, c] -> [-1, h, w, c]

        A known input format other than [HEIGHT, WIDTH, CHANNELS] rules the block out. The block parameters come
        from get_windows_to_input_info; without a matching chain the answer is False.
        """
        if self.op != "Reshape":
            return False

        known_format = self.input_format
        if known_format and known_format != [Dims.HEIGHT, Dims.WIDTH, Dims.CHANNELS]:
            return False

        _, chain, attention_params = self.get_windows_to_input_info()
        if not chain:
            return False

        height, width = attention_params["h"], attention_params["w"]
        window_h, window_size = attention_params["window_h"], attention_params["window_size"]
        num_windows = attention_params["num_windows"]
        features_in, features_out = attention_params["start_block_f_out"], attention_params["end_block_f_out"]

        return (
            height == width
            and height**2 == num_windows * window_size
            and window_size // window_h == window_h
            and features_in == features_out  # output channels
        )

    def get_windows_to_input_info(self):
        """
        Collect the chain and the parameters of a windows-to-input reshape block that starts at this vertex.

        Returns:
            tuple of (hn_shapes, chain, params). hn_shapes are the output shapes of the last chain vertex reordered
            to (batch, h, w, features) with the batch replaced by -1. params holds the keys "width_features",
            "flatten_end", "num_windows", "window_size", "window_h", "h", "w", "start_block_f_out" and
            "end_block_f_out". When no block is recognized the result is ([], [], {}).
        """
        params = {"width_features": True, "flatten_end": False}
        # default positions of height, width and features in the block output shape
        h_idx, w_idx, f_idx = [1, 2, 3]
        input_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        # the block input has to be rank 3: (num_windows, window_size, features)
        if len(input_shape) != 3:
            return [], [], {}

        params["num_windows"], params["window_size"], params["start_block_f_out"] = input_shape
        chains = [
            [FwdChainNode("Reshape"), FwdChainNode("Transpose"), FwdChainNode("Reshape")],
            [FwdChainNode("Transpose"), FwdChainNode("Reshape"), FwdChainNode("Transpose"), FwdChainNode("Reshape")],
        ]
        chain = get_all_nodes_from_possible_chains(self._graph, self, chains)
        if not chain:
            # fallback: a shorter Transpose -> Reshape chain, accepted only for a six-entry reshape shape
            chain = get_all_nodes_in_chain(self._graph, self, [FwdChainNode("Transpose"), FwdChainNode("Reshape")])
            reshape_shapes = self.get_reshape_shapes()[0]
            if not chain or len(reshape_shapes) != 6:
                return [], [], {}

            chain_out_shapes = chain[-1].get_output_shapes(convert_to_nhwc=False)
            params["window_h"] = reshape_shapes[3]
            # height and width are the products of dims [1:3] and [3:5] of the transpose output
            params["h"] = np.prod(chain[0].get_output_shapes(convert_to_nhwc=False)[0][1:3])
            params["w"] = np.prod(chain[0].get_output_shapes(convert_to_nhwc=False)[0][3:5])
            params["end_block_f_out"] = chain_out_shapes[0][-1]
            if len(chain_out_shapes[0]) == 3:
                # rank-3 block output: flag it and insert a 1 at index 1 to make the shapes rank 4
                params["flatten_end"] = True
                chain_out_shapes = [[x[0], 1, *x[1:]] for x in chain_out_shapes]
        else:
            chain_out_shapes = chain[-1].get_output_shapes(convert_to_nhwc=False)
            if len(chain_out_shapes[0]) != 4:
                return [], [], {}

            if len(chain) == 4:
                # four-vertex chain: features are read from index 1, and the window parameters are swapped
                h_idx, w_idx, f_idx = [2, 3, 1]
                params["width_features"] = False
                params["num_windows"], params["window_size"] = params["window_size"], params["num_windows"]
            params["h"] = chain_out_shapes[0][h_idx]
            params["w"] = chain_out_shapes[0][w_idx]
            params["end_block_f_out"] = chain_out_shapes[0][f_idx]
            # window height is read from the input of the last chain vertex: last dim for the four-vertex chain,
            # second to last otherwise
            params["window_h"] = chain[-1].get_input_shapes(convert_to_nhwc=False)[0][-2 if len(chain) != 4 else -1]

        # reorder every output shape to (batch, h, w, features), then replace the batch with -1
        shapes = [[shape[i] for i in [0, h_idx, w_idx, f_idx]] for shape in chain_out_shapes] if chain else []
        hn_shapes = [[-1, *shape[1:]] for shape in shapes]
        return hn_shapes, chain, params

    def is_windows_to_input_chain_end(self):
        """
        Check whether this vertex closes a windows-to-input block and report the output format of that block.

        Returns:
            tuple of (is_chain_end, output_format). is_chain_end is falsy when no backward chain is found or when
            the vertex found at the start of the chain is not an attention-windows-to-input reshape.
            output_format is [BATCH, WIDTH, CHANNELS] for a rank-3 output, [BATCH, CHANNELS, HEIGHT, WIDTH] when
            the four-vertex chain matched, and [BATCH, HEIGHT, WIDTH, CHANNELS] otherwise.
        """
        chains = [
            [BwdChainNode("Transpose"), BwdChainNode("Reshape"), BwdChainNode("Transpose"), BwdChainNode("Reshape")],
            [BwdChainNode("Transpose"), BwdChainNode("Reshape"), BwdChainNode("Reshape")],
            [BwdChainNode("Transpose"), BwdChainNode("Reshape")],
        ]
        chain = get_all_nodes_from_possible_chains(self._graph, self, chains)
        output_format = [Dims.BATCH, Dims.HEIGHT, Dims.WIDTH, Dims.CHANNELS]
        # output shapes may be missing; callers probe this method before validating shapes, so don't index blindly
        output_shapes = self.get_output_shapes(convert_to_nhwc=False)
        if output_shapes and len(output_shapes[0]) == 3:
            output_format = [Dims.BATCH, Dims.WIDTH, Dims.CHANNELS]
        elif chain and len(chain) == 4:
            output_format = [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]
        return chain and chain[-1].is_attention_windows_to_input_reshape(), output_format

    def is_groups_to_spatial_flatten(self):
        """
        Check whether this Reshape collapses a rank-5 input into a rank-3 / rank-4 output and is followed by a
        Transpose with perm [0, 2, 1] or [0, 1, 3, 2].

        Input dim 1 must equal the second-to-last output dim, and the product of input dims 2..4 must equal the
        last output dim. Returns the (falsy) lookup result itself when no Transpose follows.
        """
        if self.op != "Reshape":
            return False

        grouped_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        flat_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
        if len(grouped_shape) != 5 or len(flat_shape) not in (3, 4):
            return False

        following_transpose = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])
        return (
            following_transpose
            and following_transpose.get_transpose_perm() in ([0, 2, 1], [0, 1, 3, 2])
            and grouped_shape[1] == flat_shape[-2]
            and grouped_shape[2] * grouped_shape[3] * grouped_shape[4] == flat_shape[-1]
        )

    def get_groups_to_spatial_flatten_info(self):
        """
        Return the parameters of a groups-to-spatial-flatten block.

        Returns:
            tuple of (groups, output_shapes, consumed_vertices): dim 2 of this vertex's input shape, the output
            shapes of the following Transpose, and that Transpose as the only consumed vertex. When no Transpose
            follows, (None, None, []) is returned so the tuple has three items on every path.
        """
        transpose = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])
        if not transpose:
            # used to be a two-item tuple, which broke three-way unpacking by callers
            return None, None, []
        return self.get_input_shapes(convert_to_nhwc=False)[0][2], transpose.get_output_shapes(), [transpose]

    def is_spatial_flatten_to_groups(self):
        """
        Check whether this Reshape expands its input into a rank-5 output (_, groups, h, w, c) and is followed by
        a Transpose with perm [0, 4, 1, 2, 3].

        The last input dim must equal c, h must equal w, and groups * h * w must equal the second-to-last input
        dim. Returns the (falsy) lookup result itself when no Transpose follows.
        """
        if self.op != "Reshape":
            return False

        flat_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        grouped_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
        if len(grouped_shape) != 5:
            return False

        groups, height, width, channels = grouped_shape[1:]
        following_transpose = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])

        return (
            following_transpose
            and following_transpose.get_transpose_perm() == [0, 4, 1, 2, 3]
            and channels == flat_shape[-1]
            and height == width
            and height * width * groups == flat_shape[-2]
        )

    def get_spatial_flatten_to_groups_info(self):
        """
        Return dim 1 of this vertex's output shape followed by everything get_groups_to_spatial_flatten_info
        returns after its first item.
        """
        groups = self.get_output_shapes(convert_to_nhwc=False)[0][1]
        flatten_info = self.get_groups_to_spatial_flatten_info()
        return (groups, *flatten_info[1:])

    def is_spatial_flatten_and_groups_to_features(self):
        """
        Check whether this Reshape maps a rank-4 input to a rank-5 output where input dims 1 and 2 are merged
        into output dim 1, and input dim 3 is split over output dims 2, 3 and 4.
        """
        if self.op != "Reshape":
            return False

        in_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        out_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
        if len(in_shape) != 4 or len(out_shape) != 5:
            return False

        spatial_is_merged = in_shape[1] * in_shape[2] == out_shape[1]
        return spatial_is_merged and in_shape[3] == out_shape[2] * out_shape[3] * out_shape[4]

    def is_flatten_width_over_features_reshape(self):
        """
        This function checks if the reshape block is flattening the width dimension over the features dimension.
        e.g.: [batch, h, w, c] -> [batch, h, 1, w * c], [batch, c, w] -> [batch, w * c]

        NOTE(review): the code only accepts a rank-3 input shape and a rank-2 or rank-4 output shape, so the
        rank-4 input in the first example looks stale - confirm.
        """
        if self.op != "Reshape":
            return False

        in_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        out_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
        if len(in_shape) != 3 or len(out_shape) not in (2, 4):
            return False

        producer = next(iter(self.graph.predecessors(self)))
        # input and output formats should be the same, flattening and not transposing
        if not producer.output_format:
            return False
        if Dims.WIDTH not in producer.output_format or Dims.CHANNELS not in producer.output_format:
            return False

        # when the producer format has a height, the output entry at that position has to be 1
        if Dims.HEIGHT in producer.output_format and out_shape[producer.output_format.index(Dims.HEIGHT)] != 1:
            return False

        width = in_shape[producer.output_format.index(Dims.WIDTH)]
        features = in_shape[producer.output_format.index(Dims.CHANNELS)]
        return width * features == out_shape[-1] and out_shape[-2] == 1

    def get_flatten_width_over_features_reshape_info(self):
        """
        Return [1, width] and an empty consumed-vertices list, where width is read from the first predecessor's
        output shape at the position of WIDTH in its output format.
        """
        producer = next(iter(self.graph.predecessors(self)))
        producer_shape = producer.get_output_shapes(convert_to_nhwc=False)[0]
        width_axis = producer.output_format.index(Dims.WIDTH)
        return [1, producer_shape[width_axis]], []

    def is_conv3d_to_rank2_block(self):
        """
        Check whether this Reshape / Transpose belongs to a conv3d -> (transpose ->) reshape block whose output
        is the conv's [width, channels] shape and which feeds a LayerNormalization.

        Returns:
            bool: True when a conv with a known output format is found right before this vertex (optionally behind
            a Transpose), it is a conv3d, the block ends with a Reshape followed by a LayerNormalization, and the
            Reshape output shape equals the conv's (width, channels) dims. A Transpose without successors is
            reported as a non-match.
        """
        if self.op not in ["Reshape", "Transpose"]:
            return False

        # edge case of qwen2_vl_vision: conv3d-transpose-reshape, output is [W, C]
        chains = [[BwdChainNode("Transpose"), BwdChainNode("Conv")], [BwdChainNode("Conv")]]
        conv = get_node_from_possible_chains(self._graph, self, chains)
        if not conv or not conv.output_format or not conv.is_conv3d():
            return False

        if self.op == "Reshape":
            end = self
        else:
            # a transpose at the end of the graph has no successor to act as the closing reshape
            end = next(iter(self._graph.successors(self)), None)
            if end is None or end.op != "Reshape":
                return False

        if "LayerNormalization" not in [x.op for x in self._graph.successors(end)]:
            return False

        # shapes are only looked up once the cheap structural checks have passed
        conv_out_shape = conv.get_output_shapes(convert_to_nhwc=False)[0]
        wanted_shape = [conv_out_shape[conv.output_format.index(dim)] for dim in (Dims.WIDTH, Dims.CHANNELS)]
        out_shape = end.get_output_shapes(convert_to_nhwc=False)[0]
        return out_shape == wanted_shape

    def get_spatial_flatten_and_groups_to_features_info(self):
        """
        Return dim 2 of this vertex's (non-converted) output shape, the output shapes as returned by
        get_output_shapes with its defaults, and an empty consumed-vertices list.
        """
        groups = self.get_output_shapes(convert_to_nhwc=False)[0][2]
        output_shapes = self.get_output_shapes()
        return groups, output_shapes, []

    def is_partial_groups_to_spatial_flatten(self):
        """
        Check whether this Reshape maps a rank-5 input to a rank-4 output and is followed by a Transpose with
        perm [0, 1, 3, 2].

        The product of input dims 2.. must equal the last output dim, and the product of output dims 1 and 2 must
        equal input dim 1. Returns the (falsy) lookup result itself when no Transpose follows.
        """
        if self.op != "Reshape":
            return False

        grouped_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        flat_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
        if len(grouped_shape) != 5 or len(flat_shape) != 4:
            return False

        following_transpose = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])
        return (
            following_transpose
            and following_transpose.get_transpose_perm() == [0, 1, 3, 2]
            and np.prod(grouped_shape[2:]) == flat_shape[-1]
            and np.prod(flat_shape[1:3]) == grouped_shape[1]
        )

    def get_partial_groups_to_spatial_flatten_info(self):
        """
        Return dim 2 of this vertex's input shape, the output shapes of the following Transpose, and that
        Transpose as the only consumed vertex.
        """
        groups = self.get_input_shapes(convert_to_nhwc=False)[0][2]
        following_transpose = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])
        block_output_shapes = following_transpose.get_output_shapes()
        return groups, block_output_shapes, [following_transpose]

    def is_nhwc_to_nchw_transpose(self):
        """Check whether this Transpose has an NHWC input format and an NCHW output format."""
        if self.op != "Transpose":
            return False

        nhwc = [Dims.BATCH, Dims.HEIGHT, Dims.WIDTH, Dims.CHANNELS]
        nchw = [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]
        return self.input_format == nhwc and self.output_format == nchw

    def is_nchw_to_nhwc_transpose(self):
        """Check whether this Transpose has an NCHW input format and an NHWC output format."""
        if self.op != "Transpose":
            return False

        nchw = [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]
        nhwc = [Dims.BATCH, Dims.HEIGHT, Dims.WIDTH, Dims.CHANNELS]
        return self.input_format == nchw and self.output_format == nhwc

    def is_grouped_reduce_max(self):
        """
        Check whether this MaxPool reduces dim 2 of a [BATCH, GROUPS, CHANNELS, HEIGHT, WIDTH] input.

        The kernel must be (input dim 2, 1, 1), the padding must be valid, and the strides must be
        (input dim 2, 1, 1).
        """
        if self.op != "MaxPool":
            return False

        producer = next(iter(self.graph.predecessors(self)))
        grouped_format = [Dims.BATCH, Dims.GROUPS, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]
        if producer.output_format != grouped_format:
            return False

        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        if not input_shapes:
            return False
        reduced_dim = input_shapes[0][2]

        kernel_shape = self.get_kernel_shape()
        if len(kernel_shape) != 3:
            return False
        if kernel_shape[0] != reduced_dim or kernel_shape[1] != 1 or kernel_shape[2] != 1:
            return False

        padding = self.get_vertex_padding()[0]
        if padding != PaddingType.valid:
            return False

        strides = self.get_attribute_by_name("strides")[0].ints
        return len(strides) == 3 and strides[0] == reduced_dim and strides[1] == strides[2] == 1

    def get_grouped_reduce_max_info(self):
        """Return dim 1 of this vertex's output shape (the groups) and an empty consumed-vertices list."""
        output_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
        return output_shape[1], []

    def is_spatial_flatten_reshape(self):
        """
        Check whether this Reshape / Squeeze flattens the spatial dimensions of a rank-4 input.

        Special patterns are checked first: a spatial flatten after a group norm and a spatial flatten with
        features-to-heads both count as a match, while the end of a windows-to-input chain never does. After that
        the decision is made from the input / output shapes and the output format of the first predecessor.
        """
        if self.op not in ["Reshape", "Squeeze"]:
            return False

        if self.is_spatial_flatten_reshape_after_group_norm()[0]:
            return True

        if self._is_spatial_flatten_with_features_to_heads_reshape():
            return True

        if self.is_windows_to_input_chain_end()[0]:
            return False

        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        output_shapes = self.get_output_shapes(convert_to_nhwc=False)
        if not input_shapes or not output_shapes:
            return False

        input_shape, output_shape = input_shapes[0], output_shapes[0]
        if len(input_shape) != 4:
            return False

        pred = next(iter(self.graph.predecessors(self)))
        if len(output_shape) == 3:
            # NHWC producer: dims 1 and 2 are merged into output dim 1, the last dim is kept
            if pred.output_format == [Dims.BATCH, Dims.HEIGHT, Dims.WIDTH, Dims.CHANNELS]:
                return input_shape[1] * input_shape[2] == output_shape[1] and input_shape[3] == output_shape[2]

            # any other known format must have height and width as its last two dims
            if pred.output_format and set(pred.output_format[-2:]) != set([Dims.HEIGHT, Dims.WIDTH]):
                return False

            # the last two input dims are merged into the last output dim, dim 1 is kept
            return input_shape[2] * input_shape[3] == output_shape[2] and input_shape[1] == output_shape[1]

        # output format is [channels, height * width] - batch is flattened
        if len(output_shape) == 2 and pred.output_format == [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]:
            return input_shape[2] * input_shape[3] == output_shape[1] and input_shape[1] == output_shape[0]

        # output format is [width, groups * channels] - batch is flattened
        if len(output_shape) == 2 and pred.output_format == [Dims.WIDTH, Dims.BATCH, Dims.GROUPS, Dims.CHANNELS]:
            return input_shape[2] * input_shape[3] == output_shape[1] and input_shape[0] == output_shape[0]

        if pred.output_format:
            input_width_index = pred.output_format.index(Dims.WIDTH)
            # output index is offset by the difference in rank between input and output
            output_width_index = input_width_index + len(output_shape) - len(input_shape)

            if output_width_index >= len(output_shape):
                return False
            # the output width has to equal the product of the input width and the input dim right before it
            return output_shape[output_width_index] == np.prod(
                input_shape[input_width_index - 1 : input_width_index + 1],
            )

        # unknown producer format: rank-4 to rank-4 with input dims 1 and 2 merged into output dim 2
        if len(input_shape) == len(output_shape) == 4:
            return input_shape[1] * input_shape[2] == output_shape[2]

        return False

    def _is_spatial_flatten_with_features_to_heads_reshape(self):
        """
        Check whether this vertex flattens the last two input dims into the last output dim of a rank-4 output
        while keeping the first dim.

        The first predecessor must output [BATCH, CHANNELS, HEIGHT, WIDTH] or
        [BATCH, GROUPS, CHANNELS, HEIGHT, WIDTH]. For a rank-4 input, input dim 1 must equal the product of output
        dims 1 and 2 and the shapes must differ. For a rank-5 input, the leading three dims must be kept as is.
        The spatial-flatten-features-to-width pattern is excluded.
        """
        if self.is_spatial_flatten_features_to_width():
            return False

        producer = next(iter(self.graph.predecessors(self)))
        supported_formats = [
            [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH],
            [Dims.BATCH, Dims.GROUPS, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH],
        ]
        if producer.output_format not in supported_formats:
            return False

        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        output_shapes = self.get_output_shapes(convert_to_nhwc=False)
        if not input_shapes or not output_shapes:
            return False

        in_shape, out_shape = input_shapes[0], output_shapes[0]
        if len(in_shape) not in [4, 5] or len(out_shape) != 4:
            return False

        if in_shape[0] != out_shape[0]:
            return False
        if in_shape[-2] * in_shape[-1] != out_shape[-1]:
            return False

        if len(in_shape) == 4:
            # features are split into heads: only a real change of shape counts
            return input_shapes != output_shapes and in_shape[1] == out_shape[1] * out_shape[2]

        # rank-5 input: the shapes always differ from the rank-4 output, so only the leading dims are compared
        return in_shape[:-2] == out_shape[:-1]

    def _is_spatial_unflatten_with_g_to_f_reshape(self):
        """
        Check whether this vertex (a Reshape, or a Transpose followed by a Reshape) unflattens the width while
        merging groups into features.

        The format entering the Reshape must be [BATCH, GROUPS, CHANNELS, WIDTH]; for a Transpose it is the first
        predecessor's output format permuted by the transpose perm.
        """
        # [batch, groups, channels, width] -> [batch, groups * channels, height, width]
        producer = next(iter(self.graph.predecessors(self)))
        producer_format = producer.output_format
        if not producer_format:
            return False

        if self.op == "Transpose":
            reshape_node = look_for_node(self.graph, self, [FwdChainNode(op="Reshape")])
            if not reshape_node:
                return False
            reshape_input_format = [producer_format[axis] for axis in self.get_transpose_perm()]
        else:
            reshape_node = self
            reshape_input_format = producer_format

        if reshape_input_format != [Dims.BATCH, Dims.GROUPS, Dims.CHANNELS, Dims.WIDTH]:
            return False

        input_shapes = reshape_node.get_input_shapes(convert_to_nhwc=False)
        output_shapes = reshape_node.get_output_shapes(convert_to_nhwc=False)
        if not input_shapes or not output_shapes:
            return False

        in_shape, out_shape = input_shapes[0], output_shapes[0]
        if len(in_shape) != 4 or len(out_shape) != 4:
            return False

        return (
            in_shape[0] == out_shape[0]
            and out_shape[1] == in_shape[1] * in_shape[2]
            and out_shape[2] * out_shape[3] == in_shape[3]
        )

    def is_spatial_flatten_reshape_after_group_norm(self):
        """
        Check whether this Reshape closes a Reshape -> InstanceNormalization -> Reshape block and flattens the
        spatial dims of the tensor that entered the block.

        Returns:
            tuple of (is_match, first_reshape). is_match is True when the first reshape is a group norm reshape,
            its rank-4 input has dim 1 equal to dim 1 of this vertex's rank-3 output, and the product of its dims
            2 and 3 equals output dim 2. (False, None) is returned when the block is not found or the ranks don't
            fit.
        """
        if self.op != "Reshape":
            return False, None

        backward_chain = [BwdChainNode("InstanceNormalization"), BwdChainNode("Reshape")]
        first_reshape = look_for_node(self._graph, self, backward_chain)
        if first_reshape is None:
            return False, None
        if not first_reshape.is_group_norm_reshape()[0]:
            return False, None

        block_input_shape = first_reshape.get_input_shapes(convert_to_nhwc=False)[0]
        block_output_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
        if len(block_input_shape) != 4 or len(block_output_shape) != 3:
            return False, None

        spatial_is_flattened = block_input_shape[2] * block_input_shape[3] == block_output_shape[2]
        shapes_match = spatial_is_flattened and block_input_shape[1] == block_output_shape[1]

        return shapes_match, first_reshape

    def is_features_to_groups_reshape(self, input_format):
        """
        Check whether this reshape splits the channels dimension into groups (and possibly a stack dimension).

        Args:
            input_format: list of Dims describing the input of this reshape. Must contain Dims.CHANNELS.

        Returns:
            tuple of (is_match, output_format). output_format is the list of Dims after the reshape when is_match
            is True, and None otherwise.
        """
        if self.is_shuffle() or self.is_reshape_before_einsum():
            return False, None

        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        output_shapes = self.get_output_shapes(convert_to_nhwc=False)
        if not (input_shapes and output_shapes and input_format) or Dims.CHANNELS not in input_format:
            return False, None

        in_shape, out_shape = input_shapes[0], output_shapes[0]
        # default output format: the input format with GROUPS inserted right before CHANNELS
        default_out_format = input_format.copy()
        default_out_format.insert(input_format.index(Dims.CHANNELS), Dims.GROUPS)
        channel_i = input_format.index(Dims.CHANNELS)
        shapes_diff = len(out_shape) - len(in_shape)
        # when the rank grows, everything before the channels dim has to stay untouched
        if shapes_diff > 0 and in_shape[:channel_i] != out_shape[:channel_i]:
            return False, None

        # index just past the output dims the channels dim is split into
        end_i = channel_i + shapes_diff + 1
        if shapes_diff == 0:
            # same rank: channels are last, and equal the product of the output dims from channel_i - 1 onwards
            if channel_i == len(in_shape) - 1 and in_shape[channel_i] == np.prod(out_shape[channel_i - 1 :]):
                # edge case of groups dim replacing batch dim (yolov5s_c3tr_simp.onnx, Reshape_169)
                if input_format in [[Dims.WIDTH, Dims.BATCH, Dims.CHANNELS], [Dims.WIDTH, Dims.GROUPS, Dims.CHANNELS]]:
                    return True, [Dims.WIDTH, Dims.GROUPS, Dims.CHANNELS]
                if input_format == [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]:
                    if in_shape[2] == 1 and in_shape[3] == out_shape[3]:
                        default_out_format.remove(Dims.HEIGHT)
                        return True, default_out_format

        elif (
            in_shape[channel_i] == np.prod(out_shape[channel_i:end_i])
            and in_shape[channel_i + 1 :] == out_shape[end_i:]
        ):
            # channels are split over out_shape[channel_i:end_i], all dims after them are kept
            if shapes_diff == 1 and Dims.GROUPS not in input_format:
                # input format of [width, channels] will have batch instead of groups
                if input_format == [Dims.WIDTH, Dims.CHANNELS]:
                    default_out_format[default_out_format.index(Dims.GROUPS)] = Dims.BATCH
                return True, default_out_format

            # otherwise a following Gather / Split / Slice (optionally behind a Transpose) marks a stack dim
            chains = [[FwdChainNode("Gather")], [FwdChainNode("Split")], [FwdChainNode("Slice")]]
            chains = chains + [[FwdChainNode("Transpose")] + chain for chain in chains]
            nodes = get_all_nodes_from_possible_chains(self._graph, self, chains)
            if nodes:
                stack_idx = nodes[-1].get_axis()
                if nodes[-1].op == "Slice":
                    stack_idx = nodes[-1].get_slices_args()[0]["axes"][0]
                if len(nodes) > 1:
                    # map the axis through the transpose perm to get the index before the transpose
                    perm = nodes[0].get_transpose_perm()
                    stack_idx = perm[stack_idx]
                default_out_format = default_out_format if shapes_diff == 2 else input_format.copy()
                default_out_format.insert(stack_idx, Dims.STACK)
                return True, default_out_format

        return False, None

    def is_groups_to_features_reshape(self, input_format):
        """
        Check whether this reshape merges the groups (or stack) dim with the channels dim.

        Args:
            input_format: list of Dims describing the layout of the node's first input.

        Returns:
            A ``(is_match, output_format)`` tuple; ``output_format`` is ``None`` when there is no match.
        """
        src_shapes = self.get_input_shapes(convert_to_nhwc=False)
        dst_shapes = self.get_output_shapes(convert_to_nhwc=False)
        if not (src_shapes and dst_shapes and input_format):
            return False, None
        if Dims.GROUPS not in input_format:
            return False, None

        # a stack dim, when present, is the one that gets merged instead of the groups dim
        merged_dim = Dims.STACK if Dims.STACK in input_format else Dims.GROUPS
        merged_format = input_format.copy()
        merged_format.remove(merged_dim)

        src_shape, dst_shape = src_shapes[0], dst_shapes[0]
        if len(src_shape) != len(input_format):
            return False, None

        if merged_dim == Dims.GROUPS:
            lead_idx, trail_idx = input_format.index(Dims.GROUPS), input_format.index(Dims.CHANNELS)
        else:
            lead_idx, trail_idx = input_format.index(Dims.CHANNELS), input_format.index(Dims.STACK)

        # the two merged dims must be adjacent and everything before them must be unchanged
        adjacent_dims_merge = (
            not self.is_spatial_flatten_reshape()
            and lead_idx + 1 == trail_idx
            and src_shape[:lead_idx] == dst_shape[:lead_idx]
        )
        if not adjacent_dims_merge:
            if self.is_conv3d_to_rank2_block():
                return True, [Dims.WIDTH, Dims.CHANNELS]
            return False, None

        if (
            src_shape[lead_idx] * src_shape[trail_idx] == dst_shape[lead_idx]
            and src_shape[trail_idx + 1 :] == dst_shape[trail_idx:]
            and len(src_shape) == len(dst_shape) + 1
        ):
            return True, merged_format

        # same-rank variant: the groups dim is replaced by a batch dim of size 1
        if input_format == [Dims.WIDTH, Dims.GROUPS, Dims.CHANNELS] and dst_shape[1] == 1:
            return True, [Dims.WIDTH, Dims.BATCH, Dims.CHANNELS]

        return False, None

    def is_nhw_to_nchw_reshape(self, input_format):
        """
        Check whether this reshape only prepends one dim to a [batch, height, width] input.

        Args:
            input_format: list of Dims describing the layout of the node's first input.

        Returns:
            True if the input is rank 3 with the expected format and equals the output shape without its first dim.
        """
        if input_format != [Dims.BATCH, Dims.HEIGHT, Dims.WIDTH]:
            return False

        src_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        dst_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
        if len(src_shape) != 3 or len(dst_shape) != 4:
            return False
        return src_shape == dst_shape[1:]

    def is_reshape_before_bbox_decoder(self):
        """
        Check whether this Reshape feeds a Concat whose inputs are all spatial-flatten reshapes.

        Returns:
            True if the following Concat has more than one input and every input is a spatial-flatten reshape.
        """
        if self.op != "Reshape":
            return False

        concat = look_for_node(self._graph, self, [FwdChainNode(op="Concat")])
        if concat is None:
            return False

        concat_inputs = list(self._graph.predecessors(concat))
        return len(concat_inputs) > 1 and all(node.is_spatial_flatten_reshape() for node in concat_inputs)

    def is_spatial_unflatten(self):
        """
        Check whether this Reshape/Transpose node starts a reshape that splits the width dim into two dims.

        For a Transpose, the inspected reshape is the one that follows it, directly or behind a BatchNormalization.

        Returns:
            A truthy value if the pattern matches, a falsy one otherwise.
        """
        if self.op not in ["Reshape", "Transpose"]:
            return False

        # a match in this helper (defined elsewhere in the class) is sufficient on its own
        if self._is_spatial_unflatten_with_g_to_f_reshape():
            return True

        chains = [[FwdChainNode(op="Reshape")], [FwdChainNode(op="BatchNormalization"), FwdChainNode(op="Reshape")]]
        reshape_node = self if self.op == "Reshape" else get_node_from_possible_chains(self._graph, self, chains)
        if reshape_node is None:
            return False

        input_shapes = reshape_node.get_input_shapes(convert_to_nhwc=False)
        output_shapes = reshape_node.get_output_shapes(convert_to_nhwc=False)
        if not input_shapes or not output_shapes:
            return False
        input_shape, output_shape = input_shapes[0], output_shapes[0]

        # the layout entering this node is the output format of its first predecessor
        input_format = next(iter(self._graph.predecessors(self))).output_format
        if not input_format:
            return False

        if self.op == "Transpose":
            # apply the transpose permutation to get the layout that reaches the reshape
            perm = self.get_transpose_perm()
            input_format = [input_format[i] for i in perm]

        if Dims.WIDTH not in input_format or len(input_format) != len(input_shape):
            return False

        # NOTE(review): the guard above already returned when len(input_format) != len(input_shape), so the
        # second condition below cannot hold and this padding looks unreachable - confirm the intended order.
        if len(output_shape) - len(input_shape) == 2 and len(input_format) - len(input_shape) == 1:
            input_shape = [1, *input_shape]

        input_width_index = input_format.index(Dims.WIDTH)
        # output index is offset by the difference in rank between input and output
        output_width_index = input_width_index + len(output_shape) - len(input_shape)

        if output_width_index >= len(output_shape):
            return False

        # the input width must equal the product of the output dim at output_width_index and the one before it
        return input_shape[input_width_index] == np.prod(output_shape[output_width_index - 1 : output_width_index + 1])

    def is_reshape_transpose_expand_height_dim(self):
        """
        Check whether this Reshape and the Transpose after it only insert a dim of size 1 at position 1.

        Returns:
            True if the rank-3 input ``[a, b, c]`` comes out of the Transpose as ``[a, 1, b, c]``.
        """
        if self.op != "Reshape":
            return False

        transpose = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])
        if not transpose:
            return False

        src_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        transposed_shape = transpose.get_output_shapes(convert_to_nhwc=False)[0]
        if len(src_shape) != 3:
            return False

        first, second, third = src_shape
        return [first, 1, second, third] == transposed_shape

    def is_flatten_height_stack_reshape(self):
        """
        Check whether this node reshapes a rank-6 tensor to rank 4 by merging dims 2*3 and dims 4*5.

        Returns:
            True if dim 1 is preserved, output dim 2 equals input dims 2*3 and output dim 3 equals input dims 4*5.
        """
        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        output_shapes = self.get_output_shapes(convert_to_nhwc=False)
        # exactly one input shape and one output shape are expected
        for shapes in (input_shapes, output_shapes):
            if not shapes or len(shapes) != 1:
                return False

        src, dst = input_shapes[0], output_shapes[0]
        if (len(src), len(dst)) != (6, 4):
            return False

        return src[1] == dst[1] and src[3] * src[2] == dst[2] and src[4] * src[5] == dst[3]

    def get_flatten_height_stack_reshape_info(self):
        """
        Return the output shapes of this node and the spatial sizes taken from the first of them.

        Returns:
            A ``(output_shapes, spatial_reshape_sizes)`` tuple, where the sizes are dims 1 and 2 of the first shape.
        """
        output_shapes = self.get_output_shapes()
        return output_shapes, output_shapes[0][1:3]

    def is_features_to_stack(self):
        """
        Check whether this node splits the last (features) dim into two dims, producing a rank-6 output.

        For a Reshape the input is rank 4 and its spatial dim is also split into two; for any other op the input is
        rank 5 and the two dims before the features dim must stay unchanged.

        Returns:
            True if the shapes match one of the two patterns, False otherwise (including when shapes are missing).
        """
        output_shapes = self.get_output_shapes(convert_to_nhwc=False)
        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        # without shape information nothing can be matched; this also protects the indexing below
        if not output_shapes or not input_shapes:
            return False

        if self.op == "Reshape":
            if len(output_shapes) != 1 or len(input_shapes) != 1:
                return False

            if len(output_shapes[0]) != 6 or len(input_shapes[0]) != 4:
                return False

            g_in, w_in, f_in = input_shapes[0][-3:]
            _, g_out, h_out, w_out, f_out, s_out = output_shapes[0]
            return g_in == g_out and w_out * h_out == w_in and f_out * s_out == f_in

        # spatial is not flattened
        if len(output_shapes[0]) != 6 or len(input_shapes[0]) != 5:
            return False

        g_in, h_in, w_in, f_in = input_shapes[0][-4:]
        _, g_out, h_out, w_out, f_out, s_out = output_shapes[0]
        return g_in == g_out and w_out == w_in and h_in == h_out and f_out * s_out == f_in

    def get_spatial_unflatten_features_to_groups_info(self):
        """
        Build the target shape and spatial sizes for a spatial-unflatten combined with features-to-groups.

        Returns:
            A ``(shapes, spatial_reshape_sizes)`` tuple. ``shapes`` holds a single shape made of -1, output dims 2
            and 3, and the product of the last two output dims; ``spatial_reshape_sizes`` is output dims 2 and 3.
        """
        out_shape = self.get_output_shapes(False)[0]
        spatial_reshape_sizes = out_shape[2:4]
        merged_features = int(np.prod(out_shape[-2:]))
        return [[-1, *spatial_reshape_sizes, merged_features]], spatial_reshape_sizes

    def get_spatial_unflatten_reshape_info(self):
        """
        Collect the vertices, output shape and spatial sizes of a spatial-unflatten pattern starting at this node.

        When this node is a Reshape, a directly following Transpose (if any) is consumed as well. Otherwise the
        Reshape that follows this node is the consumed vertex.

        Returns:
            A ``(consumed_vertices, output_shapes, spatial_reshape_sizes)`` tuple; ``output_shapes`` holds the
            single output shape of the last consumed vertex.

        Raises:
            UnexpectedNodeError: if this node is not a Reshape and no Reshape is found after it.
        """
        if self.op != "Reshape":
            reshape = look_for_node(self._graph, self, [FwdChainNode(op="Reshape")])
            # the lookup yields None on a miss, so it must be checked before it is wrapped in a list
            if reshape is None:
                raise UnexpectedNodeError(f"Failed to find reshape node in format conversion layer near {self.name}.")
            consumed_vertices = [reshape]
        else:
            transpose = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])
            consumed_vertices = [self, transpose] if transpose else [self]

        output_shape = consumed_vertices[-1].get_output_shapes()[0]

        # rank-4 outputs carry the spatial dims at positions 1 and 2; otherwise they are the last two dims
        spatial_reshape_sizes = output_shape[1:3] if len(output_shape) == 4 else output_shape[-2:]
        return consumed_vertices, [output_shape], spatial_reshape_sizes

    def is_depth_to_space_reshape_transpose(self):
        """
        Check whether this Reshape and the single Transpose after it form a depth-to-space structure.

        Three variants are recognized for a rank-3 input, keyed by the output rank: 5 (transformer encoder),
        4 (transformer decoder) and 3 (multi-head attention). Each has its own shape relation and allowed perms.

        Returns:
            A ``(is_match, consumed_vertices)`` tuple; the list is empty whenever an early check fails.
        """
        if self.op != "Reshape":
            return False, []

        _, consumed_vertices = self.get_reshape_shapes()
        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        output_shapes = self.get_output_shapes(convert_to_nhwc=False)
        if not input_shapes or len(output_shapes) != 1:
            return False, []
        src, dst = input_shapes[0], output_shapes[0]

        out_rank = len(dst)
        if len(src) != 3 or out_rank not in (3, 4, 5):
            return False, []

        if out_rank == 5:
            # encoder: the last input dim is split into three output dims
            shapes_ok = src[0] == dst[0] and src[1] == dst[1] and src[2] == dst[2] * dst[3] * dst[4]
        elif out_rank == 4:
            # decoder: the last input dim is split into two output dims
            shapes_ok = src[0] == dst[0] and src[1] == dst[1] and src[2] == dst[2] * dst[3]
        else:
            # multi-head attention: the last input dim equals the product of the last two output dims
            shapes_ok = src[0] == dst[0] and src[2] == dst[1] * dst[2]
        if not shapes_ok:
            return False, []

        succs = list(self._graph.successors(self))
        if len(succs) != 1 or succs[0].op != "Transpose":
            return False, []

        perm = succs[0].get_attribute_by_name("perm")[0].ints
        if out_rank == 5:
            is_d2s_struct = perm == [2, 0, 3, 1, 4]
        elif out_rank == 4:
            is_d2s_struct = perm in [[0, 2, 1, 3], [0, 2, 3, 1]]
        else:
            is_d2s_struct = perm in [[1, 0, 2], [1, 2, 0]]

        return is_d2s_struct, consumed_vertices + succs

    def is_space_to_depth_transpose_reshape(self):
        """
        Check whether this Transpose and the single Reshape after it form a space-to-depth structure.

        Two variants are recognized: perm [0, 2, 1, 3] with a rank 4 -> 3 reshape (encoder) and perm [1, 0, 2]
        with a rank 3 -> 2/3 reshape (multi-head attention).

        Returns:
            A ``(is_match, consumed_vertices)`` tuple; the list is empty whenever an early check fails.
        """
        if self.op != "Transpose":
            return False, []

        perm = self.get_transpose_perm()
        is_encoder_perm = perm == [0, 2, 1, 3]
        is_multihead_perm = perm == [1, 0, 2]
        if not (is_encoder_perm or is_multihead_perm):
            return False, []

        succs = list(self._graph.successors(self))
        if len(succs) != 1 or succs[0].op != "Reshape":
            return False, []
        reshape = succs[0]

        _, consumed_vertices = reshape.get_reshape_shapes()
        input_shapes = reshape.get_input_shapes(convert_to_nhwc=False)
        output_shapes = reshape.get_output_shapes(convert_to_nhwc=False)
        if not input_shapes or len(output_shapes) != 1:
            return False, []
        src, dst = input_shapes[0], output_shapes[0]

        if is_encoder_perm:
            if len(src) != 4 or len(dst) != 3:
                return False, []
            # the last two input dims are merged into the last output dim
            is_s2d_struct = src[0] == dst[0] and src[1] == dst[1] and src[2] * src[3] == dst[2]
        else:
            if len(src) != 3 or len(dst) not in [2, 3]:
                return False, []
            is_s2d_struct = src[0] == dst[0] and src[1] * src[2] == dst[-1]

        return is_s2d_struct, consumed_vertices + succs

    def get_gather_index(self):
        """
        Resolve the indices of this gather node.

        The indices are read from a Constant predecessor (directly or behind a Cast) when one exists, otherwise
        from the graph's stored values for this vertex, keyed by the node's last input name.

        Returns:
            An ``(index, index_nodes)`` tuple: ``index`` is a list and ``index_nodes`` the predecessor nodes the
            indices were read from (an empty list when they came from the stored values).

        Raises:
            UnsupportedGatherLayerError: if the indices cannot be found in either place.
        """
        constant_chains = [[BwdChainNode(op="Constant")], [BwdChainNode(op="Cast"), BwdChainNode(op="Constant")]]
        index_nodes = get_all_nodes_from_possible_chains(self._graph, self, constant_chains)
        stored_values = self.graph.values_by_vertex_name.get(self.name, {})

        if index_nodes:
            index = index_nodes[-1].parse_raw_data(cast_to_int=True).tolist()
        else:
            indices_name = self._info.input[-1]
            if indices_name not in stored_values:
                raise UnsupportedGatherLayerError("Can't find index")
            index = stored_values[indices_name].tolist()

        # a scalar index is normalized to a single-element list
        if isinstance(index, int):
            index = [index]
        return index, index_nodes or []

    def is_gather_slice(self):
        """
        Check whether this gather node can be treated as a slice.

        Returns:
            True if the gather is recognized as a slice, False otherwise.
        """
        if self.op not in GATHER_OPS:
            return False

        axis = self.get_axis()
        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        if not input_shapes:
            return False
        rank = len(input_shapes[0])
        if axis >= rank:
            return False

        index, _ = self.get_gather_index()
        pred = next(iter(self.graph.predecessors(self)))
        if self.is_null_unsqueeze_gather(pred):
            return False

        pred_format = pred.output_format
        if pred_format:
            # known layout: any non-batch dim with at most two indices
            return pred_format[axis] != Dims.BATCH and len(index) <= 2

        if rank != 3:
            return len(index) <= 2

        # rank 3 without a known layout
        if axis == 1 and index[0] == 0:
            return True

        reshape_node = look_for_node(self._graph, self, [BwdChainNode(op="Transpose"), BwdChainNode(op="Reshape")])
        if not reshape_node:
            return False
        if not reshape_node.is_depth_to_space_reshape_transpose()[0]:
            return False

        reshape_output_shapes = reshape_node.get_output_shapes()
        if not reshape_output_shapes:
            return False

        # the gathered index must address one of the splits produced by the reshape
        return index[0] < reshape_output_shapes[0][2]

    def get_slice_info(self):
        """
        Translate the slice arguments of this node into per-dimension height/width/features slices.

        Returns:
            A ``(h_slice, w_slice, f_slice, groups, consumed_vertices)`` tuple. Each slice is a
            ``[start, end, step]`` list.

        Raises:
            UnsupportedSliceLayerError: if an axis cannot be mapped to a supported dim, or a step is not 1 or 2.
        """
        slice_args, consumed_vertices = self.get_slices_args()
        input_shape = self.get_input_shapes(convert_to_nhwc=False)[0]

        # defaults; [0, 0, 1] is treated further down as "no slice on this dim"
        f_slice = [0, 0, 1]
        h_slice = [0, 0, 1]
        w_slice = [0, 0, 1]
        succs = list(self.graph.successors(self))
        is_succ_spatial_flatten_complex = succs and succs[0].is_spatial_flatten_complex()[0]
        groups = 1
        if self.output_format and Dims.GROUPS in self.output_format:
            # size of the input dim that sits at the groups position of the output format
            groups = next(x for i, x in enumerate(input_shape) if self.output_format[i] == Dims.GROUPS)

        for axis, start, end, steps in zip(
            slice_args["axes"],
            slice_args["starts"],
            slice_args["ends"],
            slice_args["steps"],
        ):
            # normalize negative bounds and clip the end to the dim size
            if start < 0:
                start += input_shape[axis]
            end = min(end, input_shape[axis])
            if end < 0:
                end += input_shape[axis]

            dim_args = [start, end, steps]
            # without a known output format the dim is inferred from the raw axis number
            dim = self.output_format[axis] if self.output_format else None

            if dim == Dims.BATCH or (dim is None and axis == 0):
                default_logger().warning(
                    f"Ignoring slices on batch dim in layer {self.name} which are not supported. "
                    "It's recommended to run evaluation on the full-precision model before "
                    "optimization to verify accuracy is still good enough.",
                )
            elif dim in [Dims.CHANNELS, Dims.STACK] or (dim is None and axis in [-1, 1]):
                if dim_args[2] != 1 and groups == 1:
                    # a strided features slice is rewritten as a single-element slice with groups
                    groups = input_shape[axis] // dim_args[2]
                    dim_args[1] = dim_args[0] + 1
                    dim_args[2] = 1
                elif dim == Dims.STACK:
                    # scale the stack bounds by groups * channels to address the features dim
                    slice_size = groups * input_shape[self.output_format.index(Dims.CHANNELS)]
                    dim_args = [start * slice_size, end * slice_size, steps]
                    groups = 1
                f_slice = dim_args
            elif dim == Dims.HEIGHT or (dim is None and axis == 2):
                h_slice = dim_args
            # NOTE(review): `and` binds tighter than `or`, so the last alternative below applies regardless of
            # `dim is None` - confirm this is the intended grouping.
            elif dim == Dims.WIDTH or (dim is None and axis == 3 or (axis == 1 and is_succ_spatial_flatten_complex)):
                w_slice = dim_args
            else:
                raise UnsupportedSliceLayerError(
                    f"Failed to create slice layer at vertex {self.name}. Slice on axis {axis} is not supported",
                )

        # only steps of 1 or 2 are accepted on any dim
        if any(curr_slice[-1] not in [1, 2] for curr_slice in [h_slice, w_slice, f_slice]):
            raise UnsupportedSliceLayerError(f"Slice with stride bigger than 2 in layer {self.name} is not supported.")
        if any(spatial != [0, 0, 1] for spatial in [h_slice, w_slice]):
            groups = 1  # ignore groups in case of spatial slice

        return h_slice, w_slice, f_slice, groups, consumed_vertices

    def is_grouped_channels_pad(self):
        """
        Check whether this pad node adds a single trailing element on the channels dim of a grouped tensor.

        Returns:
            True if the padding is external-undecided and the (output format, pads) pair is one of the two
            supported grouped layouts.
        """
        # edge case: efficientvit - padding on features with groups
        if self.op not in PAD_OPS:
            return False

        padding, pads, _, _ = self.get_vertex_padding()
        if padding != TemporaryPaddingType.external_undecided:
            return False

        supported_layouts = (
            ([Dims.BATCH, Dims.GROUPS, Dims.WIDTH, Dims.CHANNELS], [0, 0, 0, 1, 0, 0]),
            ([Dims.BATCH, Dims.GROUPS, Dims.CHANNELS, Dims.WIDTH], [0, 1, 0, 0, 0, 0]),
        )
        return any(
            self.output_format == layout and pads == expected_pads for layout, expected_pads in supported_layouts
        )

    def get_grouped_pad_to_conv_info(self):
        """
        Build the kernel and bias of a grouped 1x1 conv that reproduces a one-element pad on the channels dim.

        Per group the kernel is an identity over the input channels plus one all-zero output column, and the bias
        is zero except for its last entry, which holds the pad value.

        Returns:
            A ``(kernel, bias, groups, consumed_vertices)`` tuple.
        """
        _, _, pad_value, consumed_vertices = self.get_vertex_padding()
        input_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        size_by_dim = {dim: size for dim, size in zip(self.output_format, input_shape)}
        f_in = size_by_dim[Dims.CHANNELS]
        groups = size_by_dim.get(Dims.GROUPS, 1)

        # bias of a single group: only the extra (padded) output channel gets a value
        group_bias = np.zeros(f_in + 1)
        group_bias[-1] = pad_value
        bias = np.tile(group_bias, groups)

        # kernel of a single group: identity with one appended zero column, repeated along the output channels
        group_kernel = np.concatenate([np.identity(f_in), np.zeros((f_in, 1))], axis=-1)
        tiled_kernel = np.tile(group_kernel, groups)
        kernel = tiled_kernel.reshape((1, 1, *tiled_kernel.shape))

        return kernel, bias, groups, consumed_vertices

    def is_channels_gather_to_conv(self):
        """
        Check whether this gather picks more than two indices on axis 1.

        Returns:
            True for a gather op with axis 1 and more than two indices, False otherwise.
        """
        if self.op not in GATHER_OPS:
            return False

        indices, _ = self.get_gather_index()
        gather_axis = self.get_axis()
        return gather_axis == 1 and len(indices) > 2

    def is_supported_gather(self):
        """
        Check whether this gather matches any of the supported gather patterns.

        The patterns are tried in order (slice, channels-gather-to-conv, grouped reduce-sum) and the first truthy
        result is returned.
        """
        supported = self.is_gather_slice()
        if not supported:
            supported = self.is_channels_gather_to_conv()
        if not supported:
            supported = self.is_grouped_reduce_sum_gather()
        return supported

    def get_channels_gather_to_conv_kernel(self):
        """
        Build a 1x1 conv kernel that selects input channels according to the gather indices.

        Returns:
            An array of shape ``(1, 1, f_in, f_out)`` where entry ``[0, 0, indices[k], k]`` is 1 and the rest 0.
        """
        gathered_channels, _ = self.get_gather_index()
        in_features = self.get_input_shapes()[0][-1]
        out_features = self.get_output_shapes()[0][-1]

        kernel = np.zeros((1, 1, in_features, out_features))
        # output channel k copies the input channel named by the k-th gather index
        for out_channel, in_channel in enumerate(gathered_channels):
            kernel[0, 0, in_channel, out_channel] = 1
        return kernel

    def get_gather_slice_info(self):
        """
        Translate this gather node into height/width/features slice ranges.

        Returns:
            A ``(h_slice, w_slice, f_slice, groups, index_nodes)`` tuple. Each slice is a ``[start, end]`` list
            and ``index_nodes`` are the nodes the gather indices were read from.

        Raises:
            UnsupportedSliceLayerError: if the input shapes are unknown, the gather is on the batch dim, or the
                combination of rank, axis and index is not supported.
        """
        index, index_nodes = self.get_gather_index()
        axis = self.get_axis()
        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        if not input_shapes:
            raise UnsupportedSliceLayerError(f"Unable to determine slice arguments for node {self.name}")

        rank = len(input_shapes[0])
        h_slice, w_slice, f_slice = [0, 0], [0, 0], [0, 0]
        groups = 1
        pred = next(iter(self.graph.predecessors(self)))
        # dim addressed by the gather axis, when the first predecessor's layout is known
        dim = pred.output_format[axis] if pred.output_format else None

        if rank == 4 or self.output_format:
            # one index selects a single element; otherwise the first two indices are used as [start, end]
            single_dim = len(index) == 1
            if dim == Dims.CHANNELS or (dim is None and axis == 1):
                f_slice = [index[0], index[0] + 1] if single_dim else [index[0], index[1]]
            elif dim == Dims.HEIGHT or (dim is None and axis == 2):
                h_slice = [index[0], index[0] + 1] if single_dim else [index[0], index[1]]
            elif dim == Dims.WIDTH or (dim is None and axis == 3):
                w_slice = [index[0], index[0] + 1] if single_dim else [index[0], index[1]]
            elif dim == Dims.STACK and single_dim:
                pred_out_shape = pred.get_output_shapes(convert_to_nhwc=False)[0]
                # number of features held by one stack entry: product of the groups and channels dims
                f_per_stack = np.prod(
                    [
                        pred_out_shape[i]
                        for i, dim in enumerate(pred.output_format)
                        if dim in [Dims.GROUPS, Dims.CHANNELS]
                    ],
                )
                f_per_stack = int(f_per_stack)
                if axis == 0:  # stack dim is first so this is a regular slice
                    f_slice = [index[0] * f_per_stack, (index[0] + 1) * f_per_stack]
                else:  # grouped slice
                    f_slice = [index[0], index[0] + 1]
                    groups = f_per_stack
            elif dim == Dims.BATCH or (dim is None and axis == 0):
                raise UnsupportedSliceLayerError(
                    "Gather operation on the batch dimension is not supported. Please "
                    "modify the model so that the batch dimension remains intact.",
                )
            else:
                raise UnsupportedSliceLayerError(
                    f"Unsupported slice arguments (input shape rank={rank}, axis={axis}, "
                    f"index={index}) for node {self.name}",
                )

        elif rank == 3 and index == [0] and axis == 1:
            # specific edge case for rank 3 slicing in transformer encoders
            w_slice = [0, 1]
        elif rank == 5:
            # K, Q, V slices in transformer encoders models
            # one slice spans input dims 2 * 4 features
            slice_length = input_shapes[0][2] * input_shapes[0][4]
            f_slice = [index[0] * slice_length, (index[0] + 1) * slice_length]
        else:
            raise UnsupportedSliceLayerError(
                f"Unsupported slice arguments (input shape rank={rank}, axis={axis}, "
                f"index={index}) for node {self.name}",
            )

        return h_slice, w_slice, f_slice, groups, index_nodes

    def is_grouped_reduce_sum_gather(self):
        """
        Check whether this gather heads a "slices -> two reduce waves -> concat" structure.

        The gather must select index 0 on axis 1. Each of its successors must be a slice on axis 1 followed by a
        node reducing axis -1, then a node reducing axis -2 (both keeping dims), all feeding one Concat on axis 2.

        Returns:
            True if the whole structure matches, False otherwise.
        """
        if self.op not in GATHER_OPS:
            return False

        index, _ = self.get_gather_index()
        axis = self.get_axis()
        if index != [0] or axis != 1:
            return False

        # every consumer of the gather must be a slice op working on axis 1
        slices = list(self.graph.successors(self))
        if any(x.op not in SLICE_OPS for x in slices):
            return False

        slices_axes = [x.get_axes_information() for x in slices]
        if any(x != [1] for x in slices_axes):
            return False

        # each slice must have exactly one consumer (the "first wave")
        if any(len(list(self.graph.successors(x))) != 1 for x in slices):
            return False

        first_wave_reduce_sums = [next(iter(self.graph.successors(x))) for x in slices]
        # NOTE(review): the list above has one entry per slice, so this length check can never fail
        if len(first_wave_reduce_sums) != len(slices):
            return False

        # each first-wave node must have exactly one consumer (the "second wave")
        if any(len(list(self.graph.successors(x))) != 1 for x in first_wave_reduce_sums):
            return False

        second_wave_reduce_sums = [next(iter(self.graph.successors(x))) for x in first_wave_reduce_sums]
        # NOTE(review): same as above, the lengths are equal by construction
        if len(second_wave_reduce_sums) != len(first_wave_reduce_sums):
            return False

        # the first wave reduces the last axis, the second wave the one before it
        first_wave_axes = [x.get_axes_information() for x in first_wave_reduce_sums]
        if any(x != [-1] for x in first_wave_axes):
            return False

        second_wave_axes = [x.get_axes_information() for x in second_wave_reduce_sums]
        if any(x != [-2] for x in second_wave_axes):
            return False

        all_keepdims = [x.get_keepdims() for x in first_wave_reduce_sums + second_wave_reduce_sums]
        if not all(all_keepdims):
            return False

        final_concat = look_for_node(
            self.graph,
            self,
            [
                FwdChainNode(op="Slice"),
                FwdChainNode(op="ReduceSum"),
                FwdChainNode(op="ReduceSum"),
                FwdChainNode(op="Concat"),
            ],
        )
        if final_concat is None or final_concat.get_axis() != 2:
            return False

        # the concat inputs must be exactly the second-wave nodes, in the same order
        return list(self.graph.predecessors(final_concat)) == second_wave_reduce_sums

    def is_reducing_rank_gather(self, pred):
        """
        Check whether this gather outputs a tensor with exactly one dim less than its predecessor's output.

        Args:
            pred: the predecessor node feeding the gather; it must have a known output format.

        Returns:
            True if the predecessor's output rank is the gather's output rank plus one.
        """
        if self.op not in GATHER_OPS or not pred.output_format:
            return False

        # input vertices keep their shape in the graph's tensor-shape table
        pred_out_shape = (
            pred.graph.tensor_shapes_by_vertex_name[pred.name]
            if pred.op in INPUT_OPS
            else pred.get_output_shapes(convert_to_nhwc=False)[0]
        )
        gather_out_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
        return len(pred_out_shape) == len(gather_out_shape) + 1

    def is_input_gather_increasing_rank(self, pred):
        """
        Check whether this gather, fed by an input vertex, outputs a tensor with exactly one dim more than it.

        Args:
            pred: the predecessor node feeding the gather; it must be an input op with a known output format.

        Returns:
            True if the gather's output rank is the input vertex's rank plus one.
        """
        if self.op not in GATHER_OPS:
            return False
        if not pred.output_format or pred.op not in INPUT_OPS:
            return False

        input_rank = len(pred.graph.tensor_shapes_by_vertex_name[pred.name])
        output_rank = len(self.get_output_shapes(convert_to_nhwc=False)[0])
        return output_rank - input_rank == 1

    def is_null_unsqueeze_gather(self, pred):
        """
        Check if the gather is preceded by an unsqueeze node with the same axis.

        The unsqueeze is either ``pred`` itself or, when ``pred`` is a Transpose, the Unsqueeze found before it.

        NOTE(review): the result also requires ``pred`` to be a Transpose whose perm starts with 0, so a direct
        Unsqueeze predecessor never yields a truthy result - confirm this is intended.
        """
        if self.op != "Gather" or pred.op not in ["Unsqueeze", "Transpose"]:
            return False

        if pred.op == "Transpose":
            unsqueeze = look_for_node(self.graph, pred, [BwdChainNode("Unsqueeze")])
            if not unsqueeze:
                return False
        else:
            unsqueeze = pred

        same_axis = unsqueeze.get_axes_information()[0] == self.get_axis()
        keeps_leading_dim = pred.op == "Transpose" and pred.get_transpose_perm()[0] == 0
        return same_axis and keeps_leading_dim

    def is_unsqueeze_concat_gather(self):
        """
        Check if the gather is preceded by concat and unsqueeze nodes with the same axis.
        So we can determine if the gather has output format of the input format of the unsqueeze.

        Returns:
            True if the Concat <- Unsqueeze chain exists before this node and all three axes are equal,
            False otherwise (always a bool, never the chain object itself).
        """
        gather_chain = get_all_nodes_in_chain(self._graph, self, [BwdChainNode("Concat"), BwdChainNode("Unsqueeze")])
        if not gather_chain:
            return False

        # the chain is expected to hold the Concat first and the Unsqueeze second
        concat, unsqueeze = gather_chain[0], gather_chain[1]
        return bool(self.get_axis() == concat.get_axis() == unsqueeze.get_axes_information()[0])

    def get_grouped_reduce_sum_info(self):
        """
        Collect the information needed to replace the grouped reduce-sum structure that starts at this gather.

        Returns:
            A ``(height_groups, reduce_axes, consumed_vertices)`` tuple: the number of slices after the gather,
            the fixed reduce axes ``[1, 2]``, and the slices, both reduce waves and the final Concat.
        """
        slices = list(self.graph.successors(self))
        first_wave = [next(iter(self.graph.successors(node))) for node in slices]
        second_wave = [next(iter(self.graph.successors(node))) for node in first_wave]
        final_concat = look_for_node(
            self.graph,
            self,
            [FwdChainNode("Slice"), FwdChainNode("ReduceSum"), FwdChainNode("ReduceSum"), FwdChainNode("Concat")],
        )

        consumed_vertices = [*slices, *first_wave, *second_wave, final_concat]
        return len(slices), [1, 2], consumed_vertices

    def is_matmul_layer(self):
        """
        Check whether this MatMul should be handled as a matmul layer.

        Returns:
            True for a MatMul that either has no kernel and is an elementwise op, or has a rank-4 kernel.
            False otherwise, including a MatMul that has no kernel and is not an elementwise op.
        """
        if self.op != "MatMul":
            return False

        kernel, _ = self.get_kernel(is_conv2d=False)
        if kernel is None:
            # without a kernel there is no shape to inspect; only the elementwise case qualifies
            return bool(self.is_ew_op())
        return len(kernel.shape) == 4

    def is_tokens_matmul(self):
        """
        Check whether this MatMul multiplies a [..., channels, width] input by a [..., width, channels] input.

        Returns:
            True if both inputs come from vertices with known output formats ending in those dim pairs.
        """
        if self.op != "MatMul":
            return False
        input_names = self._info.input
        if len(input_names) < 2:
            return False

        producers = self._graph.vertices_by_inp_key
        lhs, rhs = producers.get(input_names[0]), producers.get(input_names[1])
        if not (lhs and rhs):
            return False
        if not (lhs.output_format and rhs.output_format):
            return False

        lhs_tail, rhs_tail = lhs.output_format[-2:], rhs.output_format[-2:]
        return lhs_tail == [Dims.CHANNELS, Dims.WIDTH] and rhs_tail == [Dims.WIDTH, Dims.CHANNELS]

    def get_matmul_layer_info(self):
        """
        Collect the parameters needed to create a matmul layer from this node.

        Returns:
            A ``(groups, transpose_input, kernel, matmul_input, output_format)`` tuple:
            ``groups`` is the size of the groups dim of the output (1 when there is none), ``transpose_input``
            tells whether the second input has to be transposed, ``kernel`` is the node's kernel (may be None),
            ``matmul_input`` is the node's input list (reversed when the first input ends with
            [channels, width]) and ``output_format`` is the resulting layout or None when it is unknown.
        """
        kernel, _ = self.get_kernel(is_conv2d=False)
        pred0 = self._graph.vertices_by_inp_key.get(self._info.input[0])
        pred1 = self._graph.vertices_by_inp_key.get(self._info.input[1])
        pred0_format = pred0.output_format if pred0 else None
        pred1_format = pred1.output_format if pred1 else None

        perms_to_transpose = [[0, 1, 3, 2], [0, 2, 3, 1], [0, 2, 1], [1, 2, 0]]

        transpose_input = False
        if pred0_format and pred1_format:
            # both layouts are known: a transpose is needed exactly when they differ
            transpose_input = pred0_format != pred1_format
        elif pred1 is not None:
            # the second input may have no producing vertex at all, in which case there is nothing to inspect
            if pred1.op == "Transpose":
                transpose_input = pred1.get_transpose_perm() in perms_to_transpose
            else:
                # maybe the pred of pred is transpose the pred doesn't change the output format (so we need to
                # consider the transpose of the pred of pred)
                preds_of_pred = list(self._graph.predecessors(pred1))
                if (
                    len(preds_of_pred) == 1
                    and preds_of_pred[0].op == "Transpose"
                    and preds_of_pred[0].output_format == pred1_format
                ):
                    transpose_input = preds_of_pred[0].get_transpose_perm() in perms_to_transpose

        output_format = None
        matmul_input = self.input
        if pred0_format:
            output_format = pred0_format.copy()
            if pred0_format[-2:] == [Dims.CHANNELS, Dims.WIDTH]:
                matmul_input = matmul_input[::-1]

        elif pred1_format:
            if transpose_input:
                # swap the last two dims of the second input's layout
                output_format = [*pred1_format[:-2], pred1_format[-1], pred1_format[-2]]
            else:
                output_format = pred1_format.copy()

        groups = 1
        if output_format and Dims.GROUPS in output_format:
            output_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
            groups = output_shape[output_format.index(Dims.GROUPS)]

        return groups, transpose_input, kernel, matmul_input, output_format

    def get_inner_product_matmul_info(self):
        """
        Collect the Transpose <- Reshape chain before this node together with its shape information.

        Returns:
            A ``(nodes, output_shapes, input_features)`` tuple, where ``input_features`` is dim 3 of the first
            input shape, or ``(None, None, None)`` when the chain lookup returns None.
        """
        preceding_chains = [[BwdChainNode(op="Transpose"), BwdChainNode(op="Reshape")]]
        nodes = get_all_nodes_from_possible_chains(self._graph, self, preceding_chains)
        if nodes is None:
            return None, None, None

        output_shapes = self.get_output_shapes()
        input_features = self.get_input_shapes()[0][3]
        return nodes, output_shapes, input_features

    def is_transpose_before_matmul(self):
        """
        Check whether this Transpose only prepares an input of a following MatMul.

        The perm must be one of the supported ones, the MatMul must follow directly or behind a Mul, and this
        node must have a single successor.

        Returns:
            A truthy value if the MatMul is a matmul layer (or a conv1x1 matmul on a [batch, channels, width]
            input) and is not an inner-product matmul; a falsy value otherwise.
        """
        if self.op != "Transpose":
            return False

        perm = self.get_transpose_perm()
        is_plain_perm = perm in [[0, 1, 3, 2], [0, 2, 1]]
        heads_first_format = [Dims.BATCH, Dims.GROUPS, Dims.CHANNELS, Dims.WIDTH]
        is_heads_first_perm = self.output_format == heads_first_format and perm == [0, 2, 3, 1]
        if not is_plain_perm and not is_heads_first_perm:
            return False

        matmul_chains = [[FwdChainNode(op="MatMul")], [FwdChainNode(op="Mul"), FwdChainNode(op="MatMul")]]
        matmul = get_node_from_possible_chains(self._graph, self, matmul_chains)
        if matmul is None:
            return False

        consumers = list(self._graph.successors(self))
        is_channels_first = self.input_format == [Dims.BATCH, Dims.CHANNELS, Dims.WIDTH]
        return (
            len(consumers) == 1
            and (matmul.is_matmul_layer() or (matmul.is_conv1x1_matmul() and is_channels_first))
            and not matmul.is_inner_product_matmul()
        )

    def is_transposed_batch_norm(self):
        """Detect a batch normalization that is wrapped by a Transpose on each side.

        May be called on the batch-norm node itself or on the Transpose that precedes it.
        Returns True only when both surrounding transposes exist and carry one of the
        accepted permutation pairs.
        """
        if self.op in BN_OPS:
            leading_transpose = look_for_node(self._graph, self, [BwdChainNode(op="Transpose")])
            if leading_transpose is None:
                return False
            trailing_transpose = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])
            leading_perm = leading_transpose.get_transpose_perm()
        elif self.op == "Transpose":
            leading_perm = self.get_transpose_perm()
            consumers = list(self._graph.successors(self))
            # the transpose must feed a single batch-norm op and nothing else
            if len(consumers) != 1 or consumers[0].op not in BN_OPS:
                return False
            trailing_transpose = look_for_node(self._graph, consumers[0], [FwdChainNode(op="Transpose")])
        else:
            return False

        if trailing_transpose is None:
            return False

        trailing_perm = trailing_transpose.get_attribute_by_name("perm")[0].ints
        return leading_perm in [[1, 2, 0], [0, 2, 1]] and trailing_perm in [[0, 2, 1], [2, 0, 1]]

    def is_transposed_batch_norm_second_transpose(self):
        """Tell whether this Transpose is the one that closes a transposed batch-norm block.

        Walks back over a ``BatchNormalization`` to the Transpose before it and asks that node
        whether it opens a transposed batch norm.
        """
        if self.op != "Transpose":
            return False

        bn_then_transpose = [BwdChainNode("BatchNormalization"), BwdChainNode("Transpose")]
        opening_transpose = look_for_node(self._graph, self, bn_then_transpose)
        if not opening_transpose:
            return False
        return bool(opening_transpose.is_transposed_batch_norm())

    def get_transposed_bn_info(self):
        """Return the batch-norm info of a transposed batch-norm block and the vertices it consumes.

        The batch-norm node is this node when it is a BN op, otherwise the BN op that follows
        it. The Transpose after the batch norm is consumed only when its perm is ``[0, 2, 1]``.

        Returns:
            tuple: ``(bn_info, consumed_vertices)``.
        """
        if self.op in BN_OPS:
            bn_node = self
        else:
            bn_node = look_and_validate(self._graph, self, [FwdChainNode(BN_OPS[0])])

        bn_info, consumed_vertices = bn_node.get_bn_info()
        consumed_vertices.append(bn_node)

        closing_transpose = look_and_validate(self._graph, bn_node, [FwdChainNode(op="Transpose")])
        closing_perm = closing_transpose.get_attribute_by_name("perm")[0].ints
        if closing_perm == [0, 2, 1]:
            consumed_vertices.append(closing_transpose)
        return bn_info, consumed_vertices

    def get_softmax_info(self):
        """Translate the softmax axis of this node and collect its groups and additive mask.

        Returns:
            tuple: ``(groups, hailo_axis, additive_mask)``.

        Raises:
            UnsupportedSoftmaxLayerError: if the node has no output format, or the softmax axis
                does not map to the height, width or channels dimension.
        """
        onnx_input_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        axis = self.get_axis()
        if axis is None:
            # extracting default axis value by opset version
            axis = -1 if self._graph.opset_version >= 13 else 1

        additive_mask = self.get_softmax_additive_mask()
        if not self.output_format:
            raise UnsupportedSoftmaxLayerError("Unsupported softmax")

        softmax_dim = self.output_format[axis]
        hailo_axis = {Dims.HEIGHT: 1, Dims.WIDTH: 2, Dims.CHANNELS: 3}.get(softmax_dim)
        if not hailo_axis:
            raise UnsupportedSoftmaxLayerError(f"Unsupported softmax axis {axis}.")

        groups = 1
        if Dims.GROUPS in self.output_format and softmax_dim == Dims.CHANNELS:
            groups = onnx_input_shape[self.output_format.index(Dims.GROUPS)]

        return groups, hailo_axis, additive_mask

    def get_softmax_additive_mask(self, ew_add=None):
        """Return the constant additive mask applied right before a softmax, if there is one.

        Works from either end of the pattern:

        * called without ``ew_add``: this node is taken as the softmax and the ``Add`` is searched
          backwards (directly, or behind a ``Reshape``);
        * called with ``ew_add``: the softmax is searched forwards from this node (directly, or
          behind a ``Reshape``) and the given node is used as the add.

        Args:
            ew_add: optional node to use as the element-wise add.

        Returns:
            The first element (index 0) of the add's constant input when it qualifies as a mask,
            otherwise None. To qualify the constant must contain non-zero entries, its last two
            dimensions must equal the last two dimensions of this node's first input shape, its
            non-negative entries must be below 1e-10 and its negative entries below -1e10.
        """
        softmax = (
            self
            if not ew_add
            else get_node_from_possible_chains(
                self._graph,
                self,
                [[FwdChainNode(op="Reshape"), FwdChainNode(op="Softmax")], [FwdChainNode(op="Softmax")]],
            )
        )
        ew_add = (
            ew_add
            if ew_add
            else get_node_from_possible_chains(
                self._graph,
                self,
                [[BwdChainNode(op="Reshape"), BwdChainNode(op="Add")], [BwdChainNode(op="Add")]],
            )
        )
        # both ends of the add -> softmax pattern are required
        if not ew_add or not softmax or (softmax and softmax.op != "Softmax"):
            return None

        if ew_add:
            const_input_values = ew_add.get_const_input_values()
            if (
                const_input_values is not None
                and len(const_input_values[const_input_values != 0])  # not a dummy add
                and list(const_input_values.shape[-2:]) == self.get_input_shapes(convert_to_nhwc=False)[0][-2:]
                and (  # additive mask contains only zeros and -inf values
                    (const_input_values[const_input_values >= 0] < 1e-10).all()
                    and (const_input_values[const_input_values < 0] < -1e10).all()
                )
            ):
                # the softmax is preceded by an add which performs additive mask
                if ew_add.output_format == [Dims.GROUPS, Dims.WIDTH, Dims.CHANNELS]:
                    # add a leading axis, swap axes 1 and 2, then merge the last two axes
                    const_input_values = np.transpose(const_input_values[None, :], [0, 2, 1, 3])
                    const_input_values = np.reshape(const_input_values, (1, 1, const_input_values.shape[1], -1))

                if (
                    ew_add.output_format == [Dims.BATCH, Dims.GROUPS, Dims.WIDTH, Dims.CHANNELS]
                    and len(const_input_values.shape) == 2
                ):
                    # a 2D mask gets two leading singleton axes
                    const_input_values = const_input_values[None, None, :, :]
                return const_input_values[0]

        return None

    def is_additive_mask_for_softmax(self):
        """Tell whether this node, used as the add, yields an additive mask for a softmax."""
        mask = self.get_softmax_additive_mask(ew_add=self)
        return mask is not None

    def get_d2s_block_size(self):
        """Return the integer value of this node's ``blocksize`` attribute."""
        blocksize_attr = self.get_attribute_by_name("blocksize")
        return blocksize_attr[0].i

    def is_dcr_depth_to_space(self):
        """Tell whether the node uses DCR ordering.

        A missing ``mode`` attribute counts as DCR.
        """
        mode_attr = self.get_attribute_by_name("mode")
        if not mode_attr:
            return True
        return "DCR" in str(mode_attr[0].s)

    def is_null_flatten_reshape(self):
        """Tell whether this node is a flatten reshape that is followed by a dense op.

        The dense op may follow directly or behind a ``BatchNormalization`` + ``Relu`` pair.
        """
        bn_relu_dense_chains = []
        dense_only_chains = []
        for dense_op in DENSE_OPS:
            bn_relu_dense_chains.append(
                [FwdChainNode(op="BatchNormalization"), FwdChainNode(op="Relu"), FwdChainNode(op=dense_op)],
            )
            dense_only_chains.append([FwdChainNode(op=dense_op)])

        dense_node = get_node_from_possible_chains(
            self._graph,
            self,
            bn_relu_dense_chains + dense_only_chains,
        )
        is_followed_by_fc = dense_node is not None
        flattens, _ = self.is_flatten_reshape()
        return flattens and is_followed_by_fc

    def is_flatten_reshape(self):
        """Decide whether this reshape flattens its input and collect the vertices it absorbs.

        Returns:
            tuple: ``(is_flatten, consumed_vertices)``.

            * With known output shapes: True when a rank-4 input becomes a rank-2 output whose
              second dimension is the product of input dimensions 1..3; the consumed vertices
              are then ``self.get_all_shape_nodes()``. Any other rank combination gives
              ``(False, [])``.
            * Without output shapes and without a preceding ``Concat``: True when the target
              shape (from a preceding ``Constant`` or from the stored vertex values) has two
              entries.
            * Without output shapes but with a preceding ``Concat``: True when both the gather
              chain and the concat inputs pass their validation.
        """
        output_shapes = self.get_output_shapes()
        if output_shapes:
            output_shape = output_shapes[0]
            input_shape = self.get_input_shapes()[0]
            if len(output_shape) == 2 and len(input_shape) == 4:
                return output_shape[1] == input_shape[1] * input_shape[2] * input_shape[3], self.get_all_shape_nodes()

            return False, []

        # The goal here is to recognize a flatten operation implemented with dynamic shape op from predecessor.
        # In Pytorch it would look like view(x.size(0), -1) for example.
        consumed_vertices = []
        concat_node = look_for_node(self._graph, self, [BwdChainNode(op="Concat")])
        if not concat_node:
            # no Concat builds the shape: read the target shape from a Constant or the stored values
            reshape_shape = []
            const_node = look_for_node(self._graph, self, [BwdChainNode(op="Constant")])
            if const_node:
                reshape_shape = const_node.parse_raw_data(cast_to_int=True).flatten().tolist()
                consumed_vertices.append(const_node)
            elif self.graph.values_by_vertex_name.get(self.name):
                reshape_shape = self.graph.values_by_vertex_name[self.name][self._info.input[-1]]
            return len(reshape_shape) == 2, consumed_vertices

        gather_chain = None
        concat_chain = None
        consumed_vertices.append(concat_node)
        # sorting by hash gives the Unsqueeze predecessors a deterministic order
        preds_unsqueezes = sorted(
            [x for x in self._graph.predecessors(concat_node) if x.op == "Unsqueeze"],
            key=lambda x: x._hash,
        )
        if preds_unsqueezes:
            # the first Unsqueeze is checked for a Gather <- Shape chain behind it
            gather_chain = get_all_nodes_in_chain(
                self._graph,
                preds_unsqueezes[0],
                [BwdChainNode(op="Gather"), BwdChainNode(op="Shape")],
            )
            if len(preds_unsqueezes) == 2:
                # the second Unsqueeze is checked for a Cast <- Constant chain behind it
                concat_chain = get_all_nodes_in_chain(
                    self._graph,
                    preds_unsqueezes[1],
                    [BwdChainNode(op="Cast"), BwdChainNode(op="Constant")],
                )

        valid_gather, consumed_gather_nodes = self.validate_flatten_gather_axis(gather_chain)
        valid_concat, consumed_concat_nodes = self.validate_flatten_concat_inputs(concat_node, concat_chain)

        consumed_vertices.extend(preds_unsqueezes)
        consumed_vertices.extend(consumed_gather_nodes)
        consumed_vertices.extend(consumed_concat_nodes)
        return valid_gather and valid_concat, consumed_vertices

    def validate_flatten_gather_axis(self, gather_nodes):
        """Check that the gather feeding a dynamic flatten picks element 0 along axis 0.

        Args:
            gather_nodes: nodes of the ``Gather <- Shape`` chain (gather first), or a falsy
                value when that chain was not found.

        Returns:
            tuple: ``(is_valid, consumed_vertices)``. ``is_valid`` requires the gather axis to be
            0 and the gather index, read from a ``Constant`` (optionally behind a ``Cast``), to
            be 0. ``consumed_vertices`` holds the gather chain followed by the index nodes.
        """
        if not gather_nodes:
            return False, []

        gather_node = gather_nodes[0]
        axis = gather_node.get_axis()
        valid_axis = axis == 0 if axis is not None else False

        # The lookup returns None when neither chain matches; normalise it to an empty list so
        # the concatenation below cannot fail on ``list + None``.
        index_nodes = get_all_nodes_from_possible_chains(
            self._graph,
            gather_node,
            [[BwdChainNode(op="Constant")], [BwdChainNode(op="Cast"), BwdChainNode(op="Constant")]],
        ) or []

        valid_index = index_nodes[-1].parse_raw_data(cast_to_int=True).tolist() == 0 if index_nodes else False
        consumed_vertices = [*gather_nodes, *index_nodes]
        return valid_axis and valid_index, consumed_vertices

    def validate_flatten_concat_inputs(self, concat_node, concat_preds):
        """Check that the Concat building a dynamic flatten shape contributes a ``-1`` entry.

        Args:
            concat_node: the Concat that assembles the target shape.
            concat_preds: nodes of the chain behind the Concat, or a falsy value.

        Returns:
            tuple: ``(is_valid, consumed_vertices)``. Stored values of the Concat take precedence
            (first stored value compared with ``[-1]``, nothing consumed); otherwise the last
            node of ``concat_preds`` must parse to ``-1`` and the chain is consumed.
        """
        stored_values = self._graph.values_by_vertex_name.get(concat_node.name, None)
        if stored_values:
            first_value = list(stored_values.values())[0]
            return first_value == [-1], []

        if not concat_preds:
            return False, []

        tail_value = concat_preds[-1].parse_raw_data(cast_to_int=True).tolist()
        return tail_value == -1, concat_preds

    def get_resize_upscale_factors(self):
        """Resolve the height and width upscale ratios of a resize-like node.

        ``Upsample`` nodes are delegated to ``_get_upsample_upscale_factors``. ``Resize`` nodes
        read their ``scales`` input from the stored vertex values or from the vertex registered
        under that input key. Failing that, a ``Concat <- Cast`` chain behind the node is used.

        Returns:
            tuple: ``([h_ratio, w_ratio], consumed_vertices)``, or ``(None, [])`` when no source
            of scales is found.
        """
        if self.op == "Upsample":
            return self._get_upsample_upscale_factors()

        stored_inputs = self.graph.values_by_vertex_name.get(self.name, None)
        if self.op == "Resize":
            scales, consumed_vertices = [], []
            scales_key = self._info.input[RESIZE_INPUT_ORDER.index("scales")]
            if stored_inputs:
                scales = stored_inputs[scales_key]
            elif scales_key in self._graph.vertices_by_inp_key:
                scales_vertex = self._graph.vertices_by_inp_key[scales_key]
                try:
                    scales = scales_vertex.parse_raw_data()
                except TypeError:
                    # edge case when do_constant_folding=False, should_simplify=False
                    scale_parts = scales_vertex.get_constant_vertices([scales_vertex])
                    scales = np.concatenate([scale_parts[0].parse_raw_data(), scale_parts[1].parse_raw_data()])
                consumed_vertices.append(scales_vertex)

            if len(scales) > 0:
                ratios = self._get_height_width_ratios_from_scales(scales)
                return list(ratios), consumed_vertices

        cast_node = look_for_node(self._graph, self, [BwdChainNode(op="Concat"), BwdChainNode(op="Cast")])
        if not cast_node:
            return None, []
        return self._get_upscale_factors_from_concat_chain(cast_node)

    def _get_upsample_upscale_factors(self):
        """Resolve the height and width ratios of an ``Upsample`` node.

        Sources are tried in this order: a preceding ``Constant``, the ``scales`` attribute, the
        stored vertex values, and finally the ``height_scale`` / ``width_scale`` attributes
        (each defaulting to 1).

        Returns:
            tuple: ``([h_ratio, w_ratio], consumed_vertices)``; only the ``Constant`` source adds
            a consumed vertex.
        """
        consumed_vertices = []
        # Neutral default: keeps the return below valid when the stored scales value is empty
        # (previously that path left both ratios unbound).
        h_ratio, w_ratio = 1, 1
        scales_const = look_for_node(self._graph, self, [BwdChainNode(op="Constant")])
        scales_attr = self.get_attribute_by_name("scales")
        scales_vars = self.graph.values_by_vertex_name.get(self.name, None)
        if scales_const:
            consumed_vertices.append(scales_const)
            scales = scales_const.parse_raw_data(cast_to_int=True).flatten().tolist()
            h_ratio, w_ratio = self._get_height_width_ratios_from_scales(scales)
        elif scales_attr:
            h_ratio, w_ratio = self._get_height_width_ratios_from_scales(scales_attr[0].floats)
        elif scales_vars:
            var_index = UPSAMPLE_INPUT_ORDER.index("scales")
            scales_value = scales_vars[self._info.input[var_index]]
            if len(scales_value) > 0:
                h_ratio, w_ratio = self._get_height_width_ratios_from_scales(scales_value)
        else:
            height_attr = self.get_attribute_by_name("height_scale")
            width_attr = self.get_attribute_by_name("width_scale")
            h_ratio = height_attr[0].f if height_attr else 1
            w_ratio = width_attr[0].f if width_attr else 1
        return [h_ratio, w_ratio], consumed_vertices

    def _get_height_width_ratios_from_scales(self, scales):
        if len(scales) == 4:
            h_ratio, w_ratio = scales[2:4]
        elif len(scales) == 3:
            h_ratio, w_ratio = 1, scales[2]
        else:
            h_ratio, w_ratio = 1, 1
        return h_ratio, w_ratio

    def _get_upscale_factors_from_concat_chain(self, cast_node):
        """Recover resize ratios from the shape-arithmetic subgraph behind ``cast_node``.

        Two layouts are handled:

        * an inner ``Concat`` behind the cast whose ``Unsqueeze`` predecessors each lead back to a
          ``Mul`` with a ``Constant`` factor (a missing factor defaults to 1);
        * no inner ``Concat``: a ``Mul`` behind the cast whose stored ``"kernel"`` value holds
          both ratios.

        Args:
            cast_node: the ``Cast`` node found behind this node's ``Concat`` predecessor.

        Returns:
            tuple: ``([h_ratio, w_ratio], consumed_vertices)``, or ``(None, [])`` when neither an
            inner ``Concat`` nor a ``Mul`` is found behind the cast.
        """
        inner_concat_node = look_for_node(self._graph, cast_node, [BwdChainNode(op="Concat")])
        if inner_concat_node:
            # order the Unsqueeze predecessors by name; the first is used for height, the second for width
            concat_preds = {x.name: x for x in self._graph.predecessors(inner_concat_node) if x.op == "Unsqueeze"}
            sorted_unsqueezes = [concat_preds[name] for name in sorted(concat_preds.keys())]

            factor_chain = [
                [
                    BwdChainNode(op="Floor"),
                    BwdChainNode(op="Cast"),
                    BwdChainNode(op="Mul"),
                    BwdChainNode(op="Constant"),
                ],
                [BwdChainNode(op="Mul"), BwdChainNode(op="Constant")],
            ]

            height_factor_node = get_node_from_possible_chains(self._graph, sorted_unsqueezes[0], factor_chain)
            width_factor_node = get_node_from_possible_chains(self._graph, sorted_unsqueezes[1], factor_chain)

            # a missing constant factor means that dimension is left unscaled
            h_ratio = height_factor_node.parse_raw_data(cast_to_int=True).tolist() if height_factor_node else 1
            w_ratio = width_factor_node.parse_raw_data(cast_to_int=True).tolist() if width_factor_node else 1
            consumed_vertices = self.consume_resize_upscale_vertices(sorted_unsqueezes)
        else:
            factor_node = look_for_node(self._graph, cast_node, [BwdChainNode(op="Mul")])
            if not factor_node:
                return None, []

            # both ratios are stored together as the Mul's "kernel" value
            h_ratio, w_ratio = self._graph.values_by_vertex_name[factor_node.name]["kernel"].tolist()
            consumed_vertices = [cast_node, factor_node]
            # optional shape-extraction chain behind this node's Concat
            concat_shape_nodes = get_all_nodes_in_chain(
                self._graph,
                self,
                [BwdChainNode(op="Concat"), BwdChainNode(op="Slice"), BwdChainNode(op="Shape")],
            )
            if concat_shape_nodes is not None:
                consumed_vertices += concat_shape_nodes

            # optional shape-extraction chain behind the Mul
            slice_shape_nodes = get_all_nodes_in_chain(
                self._graph,
                factor_node,
                [
                    BwdChainNode(op="Slice"),
                    BwdChainNode(op="Cast"),
                    BwdChainNode(op="Gather"),
                    BwdChainNode(op="Shape"),
                ],
            )
            if slice_shape_nodes is not None:
                consumed_vertices += slice_shape_nodes

        return [h_ratio, w_ratio], consumed_vertices

    def is_unsqueeze_before_conv3d(self):
        """Tell whether this Unsqueeze produces a rank-5 tensor consumed by a 3D convolution.

        The ``Conv`` may follow directly or behind a ``Concat``.
        """
        if self.op != "Unsqueeze":
            return False

        conv_chains = [[FwdChainNode("Conv")], [FwdChainNode("Concat"), FwdChainNode("Conv")]]
        matched = get_all_nodes_from_possible_chains(self._graph, self, conv_chains)
        if matched is None:
            return False

        output_rank = len(self.get_output_shapes(convert_to_nhwc=False)[0])
        return output_rank == 5 and matched[-1].is_conv3d()

    def is_matmul_over_groups(self):
        """Tell whether this MatMul multiplies along the groups dimension with a square kernel.

        Requires a rank-4 output format that ends with channels followed by groups, and a 2D
        kernel whose two dimensions both equal the last output dimension.
        """
        if self.op != "MatMul" or not self.output_format:
            return False

        if len(self.output_format) == 4 and self.output_format[-2:] == [Dims.CHANNELS, Dims.GROUPS]:
            kernel, _ = self.get_kernel(is_conv2d=False)
            if kernel is not None:
                output_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
                # The original compared kernel.shape[0] with itself; the second operand is
                # meant to be the other kernel dimension.
                return len(kernel.shape) == 2 and kernel.shape[0] == kernel.shape[1] == output_shape[-1]

        return False

    def is_conv_over_groups(self):
        """Tell whether this 2D convolution is a 1x1 conv whose output dimension 1 is groups."""
        if self.op not in CONV2D_OPS or not self.output_format:
            return False
        if self.output_format[1] != Dims.GROUPS:
            return False

        kernel, _ = self.get_kernel()
        if kernel is None:
            return False

        kernel_shape = kernel.shape
        return kernel_shape[2] == kernel_shape[3] == 1

    def get_conv_over_groups_info(self):
        """Build the expanded kernel and bias for a convolution applied over the groups dimension.

        The original kernel (4D, or 2D for a matmul-style weight) is reduced to a 2D weight
        matrix. Every entry of that matrix is expanded into a ``features x features`` block equal
        to the entry times the identity, where ``features`` is the size of the channels dimension
        of the output.

        Returns:
            tuple: ``(kernel, bias)``. ``kernel`` has shape ``(1, 1, rows * features,
            cols * features)``; ``bias`` has every element repeated ``features`` times, or is
            None when the node has no bias.

        Raises:
            UnsupportedConvLayerError: if the kernel is neither 2D nor 4D (previously this path
                failed on an unbound local variable).
        """
        vertex_kernel, _ = self.get_kernel()
        kernel_rank = len(vertex_kernel.shape)
        if kernel_rank == 4:
            kernel = np.transpose(vertex_kernel, [2, 3, 1, 0])
        elif kernel_rank == 2:
            kernel = np.reshape(vertex_kernel, [1, 1, *vertex_kernel.shape])
        else:
            raise UnsupportedConvLayerError(
                f"Conv over groups in {self.name} expects a 2D or 4D kernel, got rank {kernel_rank}.",
            )
        bias, _ = self.get_bias()

        features = self.get_output_shapes(convert_to_nhwc=False)[0][self.output_format.index(Dims.CHANNELS)]

        # Kronecker product: block (i, j) of the result is kernel[0][0][i][j] * identity(features),
        # the same matrix the former nested loops filled block by block.
        new_kernel = np.kron(kernel[0][0], np.eye(features))
        kernel = new_kernel.reshape(1, 1, new_kernel.shape[0], new_kernel.shape[1])

        if bias is not None:
            bias = np.repeat(bias, features)

        return kernel, bias

    def is_conv3d(self):
        """Tell whether this node is a ``Conv`` that acts as a 3D convolution.

        That is the case when the kernel shape has three entries, or when the output format has
        the groups dimension at index 2.
        """
        if self.op != "Conv":
            return False
        if len(self.get_kernel_shape()) == 3:
            return True
        return self.output_format and self.output_format[2] == Dims.GROUPS

    def is_squeeze_after_conv3d(self):
        """Tell whether this Squeeze follows a 3D convolution.

        True when the first predecessor is a conv3d, or when its output format is
        batch, channels, groups, height, width.
        """
        if self.op != "Squeeze":
            return False

        producer = next(iter(self.graph.predecessors(self)))
        return (
            True
            if producer.is_conv3d()
            else producer.output_format == [Dims.BATCH, Dims.CHANNELS, Dims.GROUPS, Dims.HEIGHT, Dims.WIDTH]
        )

    def is_concat3d(self):
        """Tell whether this node is a ``Concat`` whose first output shape has rank 5."""
        if self.op != "Concat":
            return False

        shapes = self.get_output_shapes(convert_to_nhwc=False)
        return bool(shapes) and len(shapes[0]) == 5

    def consume_resize_upscale_vertices(self, unsqueeze_section_vertices):
        """Collect every vertex of the shape-computation cluster that feeds a resize node.

        Args:
            unsqueeze_section_vertices: the ``Unsqueeze`` nodes that feed the inner concat.

        Returns:
            list: this node, the given unsqueeze nodes, the nodes of the matched backward chains
            and the constant vertices returned for that collection.

        Raises:
            UnsupportedResizeLayerError: if one of the expected chains is missing.
        """

        def backward(*ops):
            return [BwdChainNode(op=name) for name in ops]

        concat_shapes_chain = backward("Concat", "Cast", "Concat")
        slice_section_chain = backward("Concat", "Slice", "Shape")
        scaled_section_chains = [
            backward("Mul", "Gather", "Shape"),
            backward("Floor", "Cast", "Mul", "Cast", "Gather", "Shape"),
        ]
        unscaled_section_chain = backward("Floor", "Cast", "Cast", "Gather", "Shape")

        collected = [self, *unsqueeze_section_vertices]
        concat_shapes_nodes = get_all_nodes_in_chain(self._graph, self, concat_shapes_chain)
        slice_section_nodes = get_all_nodes_in_chain(self._graph, self, slice_section_chain)
        if concat_shapes_nodes is None or slice_section_nodes is None:
            raise UnsupportedResizeLayerError(f"Failed to determine resize layer {self.name} scale factors.")

        collected.extend(concat_shapes_nodes)
        collected.extend(slice_section_nodes)

        for unsqueeze in unsqueeze_section_vertices:
            scaled_nodes = get_all_nodes_from_possible_chains(self._graph, unsqueeze, scaled_section_chains)
            # If scale is 1, fallback chain is searched
            unscaled_nodes = get_all_nodes_in_chain(self._graph, unsqueeze, unscaled_section_chain)
            section_nodes = scaled_nodes if scaled_nodes is not None else unscaled_nodes
            if section_nodes is None:
                raise UnsupportedResizeLayerError(f"Failed to determine resize layer {self.name} scale factors.")
            collected.extend(section_nodes)

        collected.extend(self.get_constant_vertices(collected))
        return collected

    def consume_upsample_nodes(self):
        """Collect the shape-arithmetic vertices that sit behind an upsample node.

        Starting from a ``Concat <- Div`` chain, the ``Cast`` predecessors of the ``Div`` are
        followed either to a ``Slice <- Shape`` chain or to a second ``Concat`` whose
        ``Unsqueeze`` predecessors may lead to ``Gather <- Shape`` chains.

        Returns:
            list: this node followed by every vertex that was found.
        """
        div_chain = [BwdChainNode(op="Concat"), BwdChainNode(op="Div")]
        cast_shape_chain = [BwdChainNode(op="Slice"), BwdChainNode(op="Shape")]
        second_concat_chain = [BwdChainNode(op="Concat")]
        stem_shapes_chain = [BwdChainNode(op="Gather"), BwdChainNode(op="Shape")]

        collected = [self]
        div = look_for_node(self._graph, self, div_chain)
        if not div:
            return collected

        div_nodes = get_all_nodes_in_chain(self._graph, self, div_chain)
        if div_nodes is not None:
            collected.extend(div_nodes)

        cast_preds = [pred for pred in self._graph.predecessors(div) if pred.op == "Cast"]
        collected.extend(cast_preds)
        for cast in cast_preds:
            shape_slice = look_for_node(self._graph, cast, cast_shape_chain)
            second_concat = look_for_node(self._graph, cast, second_concat_chain)
            if shape_slice:
                cast_shape_nodes = get_all_nodes_in_chain(self._graph, cast, cast_shape_chain)
                if cast_shape_nodes is not None:
                    collected.extend(cast_shape_nodes)
                continue
            if not second_concat:
                continue

            collected.append(second_concat)
            unsqueeze_preds = [pred for pred in self._graph.predecessors(second_concat) if pred.op == "Unsqueeze"]
            collected.extend(unsqueeze_preds)
            for unsqueeze_node in unsqueeze_preds:
                if look_for_node(self._graph, unsqueeze_node, stem_shapes_chain) is None:
                    continue
                stem_shapes_nodes = get_all_nodes_in_chain(self._graph, unsqueeze_node, stem_shapes_chain)
                if stem_shapes_nodes is not None:
                    collected.extend(stem_shapes_nodes)

        return collected

    def get_keras_unsqueeze_resize_nearest_block_info(self):
        """Match the Keras-style block that implements resize-nearest and read its target sizes.

        The block is the forward chain Concat, Unsqueeze, Concat, Transpose, Reshape, Transpose.

        Returns:
            tuple: ``(chain_nodes, resize_sizes)`` where ``resize_sizes`` is elements 1 and 2 of
            the first stored value of the Reshape node; ``(None, None)`` when the chain, the
            permutation check or the stored values are missing.
        """
        keras_resize_chain = [
            FwdChainNode(op=op_name)
            for op_name in ("Concat", "Unsqueeze", "Concat", "Transpose", "Reshape", "Transpose")
        ]
        keras_nodes = get_all_nodes_in_chain(self._graph, self, keras_resize_chain)
        if not keras_nodes:
            return None, None

        last_transpose_perm = keras_nodes[-1].get_transpose_perm()
        inner_transpose_perm = keras_nodes[-3].get_transpose_perm()
        # NOTE(review): the block is accepted only when the permutations DIFFER from the listed
        # ones, although the original comment said these specific permutations are required.
        # The condition is kept as it was; confirm whether it is inverted.
        perms_differ = last_transpose_perm != [0, 3, 1, 2] or inner_transpose_perm != [0, 3, 4, 5, 1, 2]
        if not perms_differ:
            return None, None

        reshape_node = keras_nodes[-2]
        if not self.graph.values_by_vertex_name.get(reshape_node.name, None):
            return None, None

        stored_values = self.graph.values_by_vertex_name[reshape_node.name].values()
        target_shape = list(next(iter(stored_values)))
        return keras_nodes, target_shape[1:3]

    def get_torch_unsqueeze_resize_nearest_block_info(self):
        """Match the Expand, Reshape block that implements resize-nearest and read its sizes.

        The Expand output must have rank 5, the Reshape output rank 4, and the product of
        Expand output dimension 2 and this node's input dimension 1 must equal Reshape output
        dimension 1.

        Returns:
            tuple: ``(chain_nodes, resize_sizes)`` where ``resize_sizes`` is the first stored
            value of the Reshape node as a list; ``(None, None)`` when nothing matches.
        """
        torch_resize_chain = [FwdChainNode(op="Expand"), FwdChainNode(op="Reshape")]
        torch_nodes = get_all_nodes_in_chain(self._graph, self, torch_resize_chain)
        if not torch_nodes:
            return None, None

        input_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        expand_shape = torch_nodes[0].get_output_shapes(convert_to_nhwc=False)[0]
        reshape_shape = torch_nodes[1].get_output_shapes(convert_to_nhwc=False)[0]
        shapes_match = (
            len(expand_shape) == 5
            and len(reshape_shape) == 4
            and expand_shape[2] * input_shape[1] == reshape_shape[1]
        )
        if not shapes_match:
            return None, None

        reshape_node = torch_nodes[1]
        if not self.graph.values_by_vertex_name.get(reshape_node.name, None):
            return None, None

        stored_values = self.graph.values_by_vertex_name[reshape_node.name].values()
        return torch_nodes, list(next(iter(stored_values)))

    def is_unsqueeze_resize_nearest(self):
        """Tell whether this Unsqueeze starts a block that implements a resize-nearest layer.

        An Einsum-based pattern is checked first; after that the Keras-style and then the
        torch-style blocks are tried, and a debug message is logged when one of them matches.
        """
        if self.op != "Unsqueeze":
            return False

        # multiple implementations from torch and keras are supported
        einsum_chain = [
            BwdChainNode(op="Einsum"),
            FwdChainNode(op="Unsqueeze"),
            FwdChainNode(op="Add"),
            BwdChainNode(op="Reshape"),
        ]
        if get_node_from_possible_chains(self._graph, self, [einsum_chain]) is not None:
            return True

        resize_vertices, _ = self.get_keras_unsqueeze_resize_nearest_block_info()
        if resize_vertices is None:
            resize_vertices, _ = self.get_torch_unsqueeze_resize_nearest_block_info()
        if resize_vertices is None:
            return False

        default_logger().debug(
            "Identified a block that implements a resize nearest layer, structure spans from "
            f"{self.name} to {resize_vertices[-1].name}",
        )
        return True

    def is_torch_resize_nearest_reshape(self):
        """Tell whether this node closes a torch-style resize-nearest block.

        Walks back over an ``Expand`` to the ``Unsqueeze`` before it and asks that node whether
        it starts a resize-nearest block.

        Returns:
            bool: always a real boolean. A missing chain used to leak ``None`` out of this
            predicate; it now yields False.
        """
        chain = [
            BwdChainNode(op="Expand"),
            BwdChainNode(op="Unsqueeze"),
        ]
        unsqueeze = look_for_node(self.graph, self, chain)
        return bool(unsqueeze and unsqueeze.is_unsqueeze_resize_nearest())

    def is_null_transpose_near_torch_tile(self):
        """Tell whether a later Transpose restores this node's input shape.

        Looks forward over ``Mul`` and ``Add`` for a ``Transpose`` and compares its first output
        shape with this node's first input shape, pairing the dimensions with ``zip``.
        """
        input_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        mul_add_transpose = [FwdChainNode(op="Mul"), FwdChainNode(op="Add"), FwdChainNode(op="Transpose")]
        next_transpose = look_for_node(self._graph, self, mul_add_transpose)
        if next_transpose is None:
            return False

        output_shape = next_transpose.get_output_shapes(convert_to_nhwc=False)[0]
        return all(x == y for x, y in zip(input_shape, output_shape))

    def is_torch_tile_reduce_max(self):
        """Tell whether this reduce node reduces axes ``[1, 2]`` of a torch tile block.

        The tile block is the backward chain ``Reshape <- Transpose <- Tile``.

        Returns:
            bool: True only when the chain exists and the ``axes`` attribute equals ``[1, 2]``.
            A missing ``axes`` attribute yields False.
        """
        possible_chain = [BwdChainNode(op="Reshape"), BwdChainNode(op="Transpose"), BwdChainNode(op="Tile")]
        tile = look_for_node(self._graph, self, possible_chain)
        if tile is None:
            return False

        axes_attr = self.get_attribute_by_name("axes")
        if not axes_attr:
            # From opset 18 the ONNX reduce ops take ``axes`` as an input rather than an
            # attribute; indexing the empty attribute list used to raise IndexError here.
            return False

        return list(axes_attr[0].ints) == [1, 2]

    def is_reshape_expand_resize_nearest(self):
        """Tell whether this reshape starts a Reshape, Expand, Reshape resize-nearest block.

        The pattern requires this node to produce a rank-6 tensor with the same element count as
        its input, with ones at positions -1 and -3 and with output dimension 1 equal to input
        dimension 3. The following Expand must keep rank 6 and the final Reshape must produce
        rank 4.

        Returns:
            bool: True when all conditions hold. Shapes whose rank is too low to match now return
            False instead of raising IndexError on the dimension lookups.
        """
        possible_chains = [[FwdChainNode(op="Expand"), FwdChainNode(op="Reshape")]]
        nodes = get_all_nodes_from_possible_chains(self._graph, self, possible_chains)
        if nodes is None:
            return False

        expand_node = nodes[0]
        reshape_node = nodes[1]
        input_shape = self.get_input_shapes()[0]
        output_shape = self.get_output_shapes()[0]

        # Check the ranks before indexing: input_shape[3] needs rank >= 4, and a rank-6 output
        # makes the -1 / -3 / 1 lookups below safe.
        if len(output_shape) != 6 or len(input_shape) < 4:
            return False

        reshape_ones_cond = (
            input_shape[3] == output_shape[1]
            and output_shape[-1] == 1
            and output_shape[-3] == 1
            and np.prod(input_shape) == np.prod(output_shape)
        )
        expand_shape_cond = len(expand_node.get_output_shapes()[0]) == 6
        output_shape_cond = len(reshape_node.get_output_shapes()[0]) == 4

        return bool(reshape_ones_cond and expand_shape_cond and output_shape_cond)

    def is_torch_tile_resize_nearest(self):
        """Tell whether this node belongs to a torch tile-based resize-nearest block.

        A tile op is matched forwards and a ``Transpose`` is matched backwards, each against
        two variants of the block (with and without an ``Add``). Any other op gives False.
        """
        if self.op in TILE_OPS:
            link = FwdChainNode
            with_add_ops = ("Transpose", "Reshape", "Add", "Transpose")
            without_add_ops = ("Transpose", "Reshape", "Transpose")
        elif self.op == "Transpose":
            link = BwdChainNode
            with_add_ops = ("Add", "Reshape", "Transpose", "Tile")
            without_add_ops = ("Reshape", "Transpose", "Tile")
        else:
            return False

        possible_chains = [[link(op=op_name) for op_name in ops] for ops in (with_add_ops, without_add_ops)]
        return get_all_nodes_from_possible_chains(self._graph, self, possible_chains) is not None

    def get_torch_tile_resize_sizes(self):
        """Compute the target sizes of a torch tile-based resize and the vertices it consumes.

        Returns:
            tuple: ``(sizes, consumed_vertices)``. ``sizes`` is
            ``[scales[1] * input_shape[2], scales[3] * input_shape[3]]`` with ``scales`` read from
            the first stored value of this node. When the matched block contains an ``Add`` only
            its first two nodes are consumed, otherwise all of them. Without a match the result
            is ``(None, [])``.
        """
        input_shape = self.get_input_shapes(convert_to_nhwc=False)[0]

        with_add_chain = [FwdChainNode(op=op_name) for op_name in ("Transpose", "Reshape", "Add", "Transpose")]
        without_add_chain = [FwdChainNode(op=op_name) for op_name in ("Transpose", "Reshape", "Transpose")]
        cluster = get_all_nodes_from_possible_chains(self._graph, self, [with_add_chain, without_add_chain])
        if cluster is None:
            return None, []

        contains_add = any(node.op == "Add" for node in cluster)
        consumed_vertices = list(cluster[0:2]) if contains_add else list(cluster)

        stored_values = self.graph.values_by_vertex_name[self.name].values()
        scales = list(next(iter(stored_values)))
        sizes = [scales[1] * input_shape[2], scales[3] * input_shape[3]]
        return sizes, consumed_vertices

    def get_torch_reshape_expand_sizes(self):
        """Read the target sizes of a Reshape, Expand, Reshape resize block.

        Returns:
            tuple: ``(sizes, consumed_vertices)`` where ``sizes`` holds dimensions 1 and 2 of the
            final Reshape's first output shape and the consumed vertices are the Expand and
            Reshape nodes; ``(None, [])`` when the chain is absent.
        """
        expand_reshape_chain = [FwdChainNode(op="Expand"), FwdChainNode(op="Reshape")]
        nodes = get_all_nodes_from_possible_chains(self._graph, self, [expand_reshape_chain])
        if nodes is None:
            return None, []

        final_shape = nodes[1].get_output_shapes()[0]
        return [final_shape[1], final_shape[2]], list(nodes)

    def get_resize_method(self):
        """Map this node to a resize method.

        ``Unsqueeze``, ``Tile`` and ``Reshape`` nodes always map to nearest neighbor. Other nodes
        are mapped through their ``mode`` attribute; when the attribute is absent the ONNX
        default ``"nearest"`` is used instead of failing on an empty attribute list.

        Returns:
            ResizeMethod: ``bilinear`` for ``"linear"``, ``nearest_neighbor`` for ``"nearest"``.

        Raises:
            UnsupportedResizeLayerError: for any other mode.
        """
        # edge case, from converted keras models
        if self.op in ["Unsqueeze", "Tile", "Reshape"]:
            return ResizeMethod.nearest_neighbor

        mode_attr = self.get_attribute_by_name("mode")
        attr_val = mode_attr[0].s.decode() if mode_attr else "nearest"  # default value in onnx
        if attr_val == "linear":
            return ResizeMethod.bilinear
        if attr_val == "nearest":
            return ResizeMethod.nearest_neighbor

        raise UnsupportedResizeLayerError(f"Unsupported resize method {attr_val}")

    def get_resize_bilinear_pixels_mode(self):
        """Map ``coordinate_transformation_mode`` to a bilinear pixels mode.

        A missing attribute is read as ``"half_pixel"``. ``"align_corners"`` maps to
        align-corners, any value containing ``"half_pixel"`` maps to half-pixels and everything
        else maps to disabled.
        """
        attr = self.get_attribute_by_name("coordinate_transformation_mode")
        if attr:
            transform_str = str(attr[0].s.decode())
        else:
            transform_str = "half_pixel"  # default value in onnx

        if transform_str == "align_corners":
            return ResizeBilinearPixelsMode.align_corners
        if "half_pixel" in transform_str:
            return ResizeBilinearPixelsMode.half_pixels
        return ResizeBilinearPixelsMode.disabled

    def get_resize_const_sizes(self):
        """Resolve the resize target sizes when they are available as constants.

        The lookup runs on the backward ``Concat`` node when one exists, otherwise on this
        vertex, and tries in order:

        1. values stored in ``graph.values_by_vertex_name`` for that node,
        2. a ``Cast`` predecessor whose first input is a known constant input,
        3. (Concat only) a backward ``Constant`` or ``Cast -> Constant`` chain,
        4. the ``Constant`` predecessors of this vertex, ordered by this vertex's inputs.

        Returns:
            tuple: ``(resize_sizes, consumed_vertices)``. ``resize_sizes`` is ``None`` (and
            nothing is consumed) when the selected branch could not read constant values.
        """
        resize_sizes = None
        consumed_vertices = []

        concat_node = look_for_node(self._graph, self, [BwdChainNode(op="Concat")])
        node_to_check = concat_node if concat_node else self
        preds_to_check = list(self._graph.predecessors(node_to_check))

        if self.graph.values_by_vertex_name.get(node_to_check.name, None):
            # on the resize itself read the "sizes" input; on a Concat read its last input
            var_index = RESIZE_INPUT_ORDER.index("sizes") if node_to_check == self else -1
            resize_sizes = list(
                self.graph.values_by_vertex_name[node_to_check.name][node_to_check._info.input[var_index]],
            )
            if len(resize_sizes) == 1:  # 1D
                resize_sizes = [1, *resize_sizes]
            if concat_node:
                consumed_vertices = self.consume_resize_const_sizes_vertices(resize_shapes_are_const=False)
        elif "Cast" in [x.op for x in preds_to_check]:
            cast = preds_to_check[[x.op for x in preds_to_check].index("Cast")]
            param_input = cast._info.input[0]
            if param_input in self._graph.vertices_by_inp_key:
                # Try to get the values from constant inputs
                const = self._graph.vertices_by_inp_key[param_input]
                resize_sizes = list(const.parse_raw_data())
                if len(resize_sizes) == 1:  # 1D
                    resize_sizes = [1, *resize_sizes]
                if concat_node:
                    consumed_vertices = self.consume_resize_const_sizes_vertices(resize_shapes_are_const=False)
            # otherwise the Cast is not fed by a known constant: leave sizes unresolved
            # (previously this path failed with TypeError on len(None))
        elif concat_node:
            possible_shapes_chains = [
                [BwdChainNode(op="Constant")],
                [BwdChainNode(op="Cast"), BwdChainNode(op="Constant")],
            ]
            shapes_node = get_node_from_possible_chains(self._graph, node_to_check, possible_shapes_chains)
            if shapes_node:
                resize_sizes = shapes_node.parse_raw_data(cast_to_int=True).flatten().tolist()
                consumed_vertices = self.consume_resize_const_sizes_vertices(resize_shapes_are_const=True)
        else:
            const_inputs = [x for x in self._graph.predecessors(node_to_check) if x.op == "Constant"]
            # splits x by ':' that stands alone and not part of '::' - onnx::resize:25 -> [onnx::resize, 25]
            split_inputs = [re.split("(?<!:):(?!:)", x)[0] for x in self.input]
            # order the constants like this vertex's inputs; the last one is used as the sizes
            sorted_inputs = sorted([x.name for x in const_inputs], key=lambda x: split_inputs.index(x))
            resize_sizes = (
                self.graph.get_vertex_by_name(sorted_inputs[-1]).parse_raw_data(cast_to_int=True).flatten().tolist()
            )
            if len(resize_sizes) == 3:  # 1D
                resize_sizes.insert(2, 1)
            consumed_vertices.extend(const_inputs)

        return resize_sizes, consumed_vertices

    def get_resize_dynamic_sizes(self):
        """Find the vertices whose runtime shapes define this resize's target sizes.

        Both backward chains listed in ``chains`` must exist; otherwise ``(None, None)`` is
        returned.

        Returns:
            tuple: ``(resize_vertices, consumed_vertices)``. ``resize_vertices`` is a
            two-slot list filled with the first predecessor of each ``Shape`` node reached
            through the ``Cast -> Concat -> ... -> Gather -> Shape`` branch, and
            ``consumed_vertices`` lists the helper nodes visited on the way.
        """
        resize_vertices = [None, None]
        consumed_vertices = []
        chains = [
            [BwdChainNode(op="Concat"), BwdChainNode(op="Slice"), BwdChainNode(op="Shape")],
            [
                BwdChainNode(op="Concat"),
                BwdChainNode(op="Cast"),
                BwdChainNode(op="Concat"),
                BwdChainNode(op="Unsqueeze"),
                BwdChainNode(op="Gather"),
                BwdChainNode(op="Shape"),
            ],
        ]
        # every chain has to match, a single miss means this is not the dynamic-sizes pattern
        for chain in chains:
            if not look_for_node(self._graph, self, chain):
                return None, None

        concat_node = look_for_node(self._graph, self, [BwdChainNode(op="Concat")])
        consumed_vertices.append(concat_node)

        concat_preds = self._graph.predecessors(concat_node)
        for pred in concat_preds:
            if pred.op == "Slice":
                # Slice <- Shape branch: only recorded as consumed
                shape_node = look_for_node(self._graph, pred, [BwdChainNode(op="Shape")])
                consumed_vertices.extend([pred, shape_node])
            elif pred.op == "Cast":
                inner_concat_node = look_for_node(self._graph, pred, [BwdChainNode(op="Concat")])
                consumed_vertices.extend([pred, inner_concat_node])
                for inner_pred in self._graph.predecessors(inner_concat_node):
                    shape_node = look_for_node(
                        self._graph,
                        inner_pred,
                        [BwdChainNode(op="Gather"), BwdChainNode(op="Shape")],
                    )
                    gather_node = look_for_node(self._graph, inner_pred, [BwdChainNode(op="Gather")])
                    consumed_vertices.extend([inner_pred, shape_node, gather_node])
                    gather_const_node = look_for_node(self._graph, gather_node, [BwdChainNode(op="Constant")])
                    # the gathered index minus 2 selects the slot in resize_vertices
                    # NOTE(review): presumably indices 2/3 are the NCHW height/width axes - confirm
                    index = gather_const_node.parse_raw_data(cast_to_int=True).tolist() - 2
                    resize_vertices[index] = next(iter(self._graph.predecessors(shape_node)))

        return resize_vertices, consumed_vertices

    def consume_resize_const_sizes_vertices(self, resize_shapes_are_const):
        """Collect the vertices that form this resize's sizes sub-graph.

        The result starts with ``self``, followed by the backward ``Concat -> Slice -> Shape``
        chain when present, then (only if ``resize_shapes_are_const``) the nodes of a backward
        ``Concat -> Cast`` or ``Concat -> Constant`` chain, and finally every ``Constant``
        predecessor of the vertices gathered so far.
        """
        gathered = [self]

        shape_chain = [BwdChainNode(op="Concat"), BwdChainNode(op="Slice"), BwdChainNode(op="Shape")]
        shape_chain_nodes = get_all_nodes_in_chain(self._graph, self, shape_chain)
        if shape_chain_nodes is not None:
            gathered.extend(shape_chain_nodes)

        if resize_shapes_are_const:
            const_chain_nodes = get_all_nodes_from_possible_chains(
                self._graph,
                self,
                [
                    [BwdChainNode(op="Concat"), BwdChainNode(op="Cast")],
                    [BwdChainNode(op="Concat"), BwdChainNode(op="Constant")],
                ],
            )
            if const_chain_nodes is not None:
                gathered.extend(const_chain_nodes)

        # the constants are looked up over the full list before being appended to it
        gathered.extend(self.get_constant_vertices(gathered))
        return gathered

    def get_constant_vertices(self, vertices):
        """Return the ``Constant`` predecessors of every vertex in ``vertices``.

        The result follows the iteration order of ``vertices`` and of each vertex's
        predecessors; duplicates are kept.
        """
        return [
            pred
            for vertex in vertices
            for pred in self._graph.predecessors(vertex)
            if pred.op == "Constant"
        ]

    def get_slices_args(self):
        """Gather the slicing arguments of this Slice vertex.

        Every input after the first is named through ``SLICE_INPUT_ORDER`` and resolved, in
        order of preference, from ``_graph.output_shapes`` (key ``<input>_value``), from the
        values stored for this vertex in ``graph.values_by_vertex_name``, or from a
        ``Constant`` predecessor that outputs it. Node attributes with the same names
        override the input-derived values.

        Missing ``axes`` default to all axes of the first output shape and missing ``steps``
        default to 1 per start. ``steps`` and ``axes`` are always returned as lists.

        Returns:
            tuple: ``(slice_args, const_vertices)`` where ``const_vertices`` are the
            ``Constant`` predecessors of this vertex.

        Raises:
            UnsupportedSliceLayerError: if a name in ``SLICE_ATTRS_ORDER`` could not be resolved.
        """
        slice_args = {}
        input_indices = list(self._info.input)
        const_vertices = [x for x in self._graph.predecessors(self) if x.op == "Constant"]
        # a vertex without stored values yields None; treat it as "no initializers" so the
        # membership test below cannot raise TypeError
        var_initializers = self.graph.values_by_vertex_name.get(self.name) or {}
        for input_index, input_key in enumerate(input_indices):
            if input_index == 0:
                # the first input is skipped, only the argument inputs are resolved
                continue
            input_name = SLICE_INPUT_ORDER[input_index]
            if input_key + "_value" in self._graph.output_shapes:
                slice_args[input_name] = self._graph.output_shapes[input_key + "_value"]
            elif input_key in var_initializers:
                slice_args[input_name] = var_initializers[input_key]
            elif any(input_key in pred.output for pred in const_vertices):
                pred = next(x for x in const_vertices if input_key in x.output)
                slice_args[input_name] = pred.parse_raw_data(cast_to_int=True).flatten().tolist()

        # attributes take precedence over values resolved from inputs
        for input_name in SLICE_INPUT_ORDER:
            if self.get_attribute_by_name(input_name):
                slice_args[input_name] = list(self.get_attribute_by_name(input_name)[0].ints)

        if not all(name in slice_args for name in SLICE_ATTRS_ORDER):
            raise UnsupportedSliceLayerError(
                f"Unexpected Slice operation {self.name}, has no attributes/constants to define its "
                f"parameters axes, starts, ends.",
            )

        if "axes" not in slice_args:
            slice_args["axes"] = list(
                range(len(self.get_output_shapes(convert_to_nhwc=False)[0])),
            )

        if "steps" not in slice_args:
            slice_args["steps"] = [1] * len(slice_args["starts"])
        else:
            # steps may be a 0-d array, an n-d array or already a plain list (Constant path);
            # atleast_1d handles all of them, where reading ``.shape`` failed on lists
            slice_args["steps"] = np.atleast_1d(slice_args["steps"]).tolist()

        if isinstance(slice_args["axes"], np.ndarray):
            slice_args["axes"] = slice_args["axes"].tolist()
        if not isinstance(slice_args["axes"], list):
            slice_args["axes"] = [slice_args["axes"]]

        return slice_args, const_vertices

    def is_null_operation(self):
        """Tell whether this vertex can be treated as a no-op.

        Only ops listed in ``OPTIONAL_NULL_OPS`` are candidates. The individual checks run
        in a fixed order and the first one that matches makes the vertex null.
        """
        if self.op not in OPTIONAL_NULL_OPS:
            return False

        # checks that are not tied to one specific op
        if self.op in MATH_OPS or self.is_null_add() or self.is_null_reshape()[0] or self.op in DROP_OPS:
            return True

        # op-specific checks
        if self.op in PAD_OPS and self.is_null_padding():
            return True
        if self.op in ["Reshape", "Flatten"] and self.is_null_flatten_reshape():
            return True
        if self.op in ["Reshape", "Transpose"] and self.is_channel_shuffle_null_ops():
            return True
        if self.op in GATHER_OPS and not self.is_supported_gather():
            return True
        if self.op == "Expand" and self.is_broadcast_expand():
            return True
        if self.op == "Squeeze" and (self.is_null_squeeze() or self.is_squeeze_after_conv3d()):
            return True
        if self.op == "Unsqueeze" and self.is_null_unsqueeze():
            return True
        if self.op == "Transpose" and self.is_null_transpose():
            return True
        if self.op == "Slice" and self.is_null_slice():
            return True
        if self.op == "Tile" and self.is_null_tile():
            return True
        if self.op in SPLIT_OPS and self.is_null_split_over_groups():
            return True
        if self.op == "ReduceMean" and self.is_null_reduce_mean():
            return True

        if self.is_successive_unsqueeze_flat_to_frame():
            return True

        if self.op != "Clip":
            return False
        if self.is_null_clip():
            return True
        if not self.is_clip_to_positive():
            return False

        # clipping to positive values is redundant after a single non-negative producer
        producers = list(self.graph.predecessors(self))
        return bool(len(producers) == 1 and producers[0].is_non_negative_result())

    def is_null_split_over_groups(self):
        """Check for a split over the groups dim that is re-joined over channels.

        True only when the single successor is a ``Concat`` fed exclusively by this vertex,
        both vertices have an output format, the split axis is ``Dims.GROUPS`` and the
        concat axis is ``Dims.CHANNELS``.
        """
        consumers = list(self._graph.successors(self))
        if len(consumers) != 1:
            return False

        concat = consumers[0]
        if concat.op != "Concat":
            return False

        concat_sources = set(self._graph.predecessors(concat))
        formats_known = self.output_format and concat.output_format
        if not formats_known or concat_sources != {self}:
            return False

        split_dim = self.output_format[self.get_axis()]
        return split_dim == Dims.GROUPS and concat.output_format[concat.get_axis()] == Dims.CHANNELS

    def is_null_reduce_mean(self):
        """Return True when every reduced axis (after NHWC conversion) has size 1."""
        shape = self.get_input_shapes()[0]
        nhwc_axes = self._convert_axes_to_nhwc(self.get_axes_information())
        for axis in nhwc_axes:
            if shape[axis] != 1:
                return False
        return True

    def is_null_tile(self):
        """Return True when the tile repeats are known and every one of them equals 1."""
        tile_repeats, _consumed = self.get_tile_repeats()
        if tile_repeats is None:
            return False
        return not any(repeat != 1 for repeat in tile_repeats)

    def is_null_slice(self):
        """Return True for a slice op that keeps its input shape and uses only unit steps."""
        if self.op not in SLICE_OPS:
            return False

        shape_before = self.get_input_shapes(convert_to_nhwc=False)[0]
        shape_after = self.get_output_shapes(convert_to_nhwc=False)[0]
        if shape_before != shape_after:
            return False

        steps = self.get_slices_args()[0]["steps"]
        return all(step == 1 for step in steps)

    def is_empty_slice(self):
        """Return True for a slice op whose first output shape contains a zero dimension."""
        return self.op in SLICE_OPS and 0 in self.get_output_shapes(convert_to_nhwc=False)[0]

    def get_null_vertices(self):
        """Return the extra vertices consumed together with this null vertex.

        Returns:
            tuple: ``(vertices, reverse_insertion)``. ``reverse_insertion`` is True only for
            the flatten-reshape case; the original names order depends on whether the chain
            is bwd/fwd, and for a bwd chain it should be reversed.
        """
        if self.is_null_add():
            return [], False

        null_reshape, reshape_vertices = self.is_null_reshape()
        if null_reshape:
            return reshape_vertices, False

        if self.is_flat_to_frames_reshape():
            return self.get_flat_to_frames_reshape_info()[-1], False

        if self.op == "Clip":
            return self.get_clip_info()[-1], False

        flatten, flatten_vertices = self.is_flatten_reshape()
        if flatten:
            return flatten_vertices + self.get_flatten_chain(fwd=True), True

        if self.op in ["Squeeze", "Unsqueeze"]:
            # a directly following Squeeze/Unsqueeze is consumed with this vertex
            follower_chains = [[FwdChainNode(op="Squeeze")], [FwdChainNode(op="Unsqueeze")]]
            follower = get_node_from_possible_chains(self._graph, self, follower_chains)
            if follower:
                return [follower], False

        if self.op in SPLIT_OPS and self.is_null_split_over_groups():
            return list(self._graph.successors(self)), False

        return [], False

    def is_non_negative_result(self):
        """Return True for a reduce-L2 op, or for a mul/pow op that ``is_square`` confirms."""
        if self.op in REDUCE_L2_OPS:
            return True
        if self.op not in MUL_OPS + POW_OPS:
            return False
        return bool(self.is_square())

    def validate_reduce_mean_as_pooling_layer(self):
        """Validate that this reduce mean only reduces height, width and/or batch.

        The axes come from the ``axes`` attribute when present, otherwise from the
        initializer/constant input. A warning is logged when the batch dim is reduced.

        Raises:
            UnsupportedReduceMeanLayerError: if any reduced dim is not height, width or batch.
        """
        axes_attr = self.get_attribute_by_name("axes")
        if axes_attr:
            axes = list(axes_attr[0].ints)
        else:
            axes = self.get_initializer_or_constant_value(MIN_INPUT_ORDER)

        producer = next(iter(self._graph.predecessors(self)))
        dims = [producer.output_format[axis] for axis in axes]

        if Dims.BATCH in dims:
            default_logger().warning(
                f"Ignoring axis 0 on reduce mean layer {self.name}, reducing will work correctly "
                "during inference on batch=1 only. It's recommended to run evaluation on the "
                "full-precision model before optimization to verify accuracy is still good "
                "enough.",
            )

        unsupported_dims = set(dims) - {Dims.HEIGHT, Dims.WIDTH, Dims.BATCH}
        if unsupported_dims:
            raise UnsupportedReduceMeanLayerError(f"Reduce mean layer {self.name} has unsupported axes {axes}")

    def is_valid_reduce_max_min(self):
        """Check that this reduce max/min reduces only over the channels dim.

        The vertex must keep dims (or be a reduce max after a group-conv einsum). Axes come
        from the ``axes`` attribute or the initializer/constant input; when they are a plain
        list, falsy entries are dropped. Without an input format, only axis 1 is accepted.
        """
        if not (self.get_keepdims() or self.is_reduce_max_after_group_conv_einsum()):
            return False

        axes_attr = self.get_attribute_by_name("axes")
        if axes_attr:
            axes = axes_attr[0].ints
        else:
            axes = self.get_initializer_or_constant_value(MIN_INPUT_ORDER)
        if isinstance(axes, list):
            axes = [axis for axis in axes if axis]
        if not axes:
            return False

        for axis in axes:
            if self.input_format:
                reduces_channels = self.input_format[axis] == Dims.CHANNELS
            else:
                reduces_channels = axis == 1
            if not reduces_channels:
                return False

        return True

    def get_hailo_reduce_groups(self):
        """Return the size of the groups dim of the first input, or 1 when there is none.

        The groups dim is located through the first predecessor's output format and read
        from this vertex's first input shape in the original (non-NHWC) layout.
        """
        producer = next(iter(self._graph.predecessors(self)))

        if not producer.output_format or Dims.GROUPS not in producer.output_format:
            return 1

        first_input_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        return first_input_shape[producer.output_format.index(Dims.GROUPS)]

    def get_reduce_sum_info(self):
        """Translate this reduce sum's ONNX axes into Hailo axes and group information.

        Each ONNX axis is classified through the first predecessor's output format when it
        exists, otherwise through its position. Channels/groups map to axis 3, width to
        axis 2 and height to axis 1; a batch axis is skipped with a warning, and any other
        axis marks the layer as invalid.

        Returns:
            tuple: ``(is_valid, hailo_axes, groups, interleaved_groups)``.
        """
        onnx_input_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        onnx_axes = self.get_axes_information()
        if onnx_axes is None:  # no axes means all axes
            # axis 0 is left out of the default range
            onnx_axes = list(range(1, len(onnx_input_shape)))

        pred = next(iter(self._graph.predecessors(self)))
        hailo_axes = []
        groups = 1
        is_valid = True
        interleaved_groups = False
        for onnx_axis in onnx_axes:
            # dim stays None when the predecessor has no output format; the positional
            # fallbacks in the conditions below are used in that case
            dim = pred.output_format[onnx_axis] if pred.output_format else None
            if dim in [Dims.CHANNELS, Dims.GROUPS] or (dim is None and onnx_axis == 1):
                hailo_axes.append(3)
                if dim == Dims.CHANNELS:
                    groups = self.get_hailo_reduce_groups()
                elif dim == Dims.GROUPS:
                    # reducing over the groups dim: the group count is read from the size of
                    # the channels dim, and is flagged as interleaved unless it equals 1
                    groups = onnx_input_shape[pred.output_format.index(Dims.CHANNELS)]
                    interleaved_groups = groups != 1
            elif dim == Dims.WIDTH or (dim is None and onnx_axis in [-1, 3]):
                hailo_axes.append(2)
            elif dim == Dims.HEIGHT or (dim is None and onnx_axis in [-2, 2]):
                hailo_axes.append(1)
            elif dim == Dims.BATCH or (dim is None and onnx_axis == 0):
                # the batch axis is not added to hailo_axes, only reported
                default_logger().warning(
                    f"Ignoring axis 0 on reduce sum layer {self.name}, reducing will work "
                    "correctly during inference on batch=1 only. It's recommended to run "
                    "evaluation on the full-precision model before optimization to verify "
                    "accuracy is still good enough.",
                )
            else:
                is_valid = False

        return is_valid, hailo_axes, groups, interleaved_groups

    def is_space_to_depth(self):
        """Detect a space-to-depth (block size 2) pattern built from strided slices.

        All successors of this slice's non-constant predecessor must be slices that step by
        2 over axes 2 and 3 (either in one slice or in two chained slices) and that all
        feed, each through a single successor, the same concat node.

        Returns:
            bool: True when 2 or 4 (height_start, width_start) pairs were collected and each
            of them is one of (0, 0), (0, 1), (1, 0), (1, 1).
        """
        # This function now allows only space to depth with block size == 2
        # We expect to see two possible implementations:
        # 1. Basic ONNX implementation of S2D: 4 pairs of slices, where each pair has a different
        #    [height_slice_start, width_slice_start].
        # 2. YoloP variant of S2D: 2 pairs of slices, where the concatenation in the end of the stem covers the entire
        #    tensor.
        if self.op not in SLICE_OPS:
            return False

        pred = next(x for x in self._graph.predecessors(self) if x.op not in CONST_OPS)
        successors = list(self._graph.successors(pred))
        slice_start_pairs = []
        for succ in successors:
            if succ.op not in SLICE_OPS:
                return False
            slices_args, _ = succ.get_slices_args()
            if len(slices_args["axes"]) == len(slices_args["steps"]) == 2:
                # a single slice handles both axes 2 and 3
                if slices_args["axes"] != [2, 3] or slices_args["steps"] != [2, 2]:
                    return False
                next_successors = [succ]
                slice_start_pairs.extend([(slices_args["starts"][0], slices_args["starts"][1])])

            else:
                # two-level form: this slice works on axis 2, its successor slices on axis 3
                next_successors = list(self._graph.successors(succ))
                if len(next_successors) not in [1, 2] or next_successors[0].op not in SLICE_OPS:
                    return False

                successors_slice_args = [x.get_slices_args()[0] for x in next_successors]
                if (
                    slices_args["axes"][0] != 2
                    or slices_args["steps"][0] != 2
                    or any(x["axes"][0] != 3 for x in successors_slice_args)
                    or any(x["steps"][0] != 2 for x in successors_slice_args)
                ):
                    return False

                successors_slice_pairs = [(slices_args["starts"][0], x["starts"][0]) for x in successors_slice_args]
                slice_start_pairs.extend(successors_slice_pairs)

            # every last-level slice must have exactly one successor, and it must be a concat
            second_level_successors = [list(self._graph.successors(x)) for x in next_successors]
            if (
                (len(second_level_successors) not in [1, 2])
                or any(len(x) != 1 for x in second_level_successors)
                or any(x[0].op not in CONCAT_OPS for x in second_level_successors)
            ):
                return False

            # all last-level slices of this branch must feed the same concat
            concat_node = second_level_successors[0][0]
            if any(x[0] != concat_node for x in second_level_successors):
                return False

        return len(slice_start_pairs) in [2, 4] and all(
            pair in [(0, 0), (0, 1), (1, 0), (1, 1)] for pair in slice_start_pairs
        )

    def is_reversed_argmax_slice(self):
        """Return True for a slice that reverses axis 1 and is directly followed by ArgMax.

        The slice must start at -1, step by -1 over axis 1 and have ``ends`` that compare
        below ``[-(2**31 - 1)]``.
        """
        if self.op not in SLICE_OPS:
            return False

        args = self.get_slices_args()[0]
        lowest_end = [-1 * (2**31 - 1)]
        reverses_axis_1 = (
            list(args["starts"]) == [-1]
            and list(args["ends"]) < lowest_end
            and list(args["axes"]) == [1]
            and list(args["steps"]) == [-1]
        )

        following_argmax = look_for_node(self._graph, self, [FwdChainNode(op="ArgMax")])
        return following_argmax is not None and reverses_axis_1

    def get_argmax_info(self):
        """Validate the ArgMax vertex and return it as the only consumed vertex.

        When this vertex is a slice op, the ArgMax is the node that directly follows it;
        otherwise this vertex itself is used.

        Raises:
            UnsupportedLogitsLayerError: if the ``axis`` attribute is missing, is not a single
                attribute, or is different from 1.
        """
        if self.op in SLICE_OPS:
            argmax = look_and_validate(self._graph, self, [FwdChainNode(op="ArgMax")])
        else:
            argmax = self

        axis_attr = argmax.get_attribute_by_name("axis")
        if not axis_attr or len(axis_attr) != 1:
            raise UnsupportedLogitsLayerError(f"ArgMax layer {argmax.name} has invalid axis.")
        if axis_attr[0].i != 1:
            raise UnsupportedLogitsLayerError(f"ArgMax layer {argmax.name} has unsupported axis {axis_attr[0].i}.")

        return [argmax]

    def get_space_to_depth_info(self):
        """Collect the vertices that implement the slice-based space-to-depth pattern.

        Returns:
            tuple: ``(DEFAULT_SPACE_TO_DEPTH_BLOCK_SIZE, consumed_vertices)`` where the
            consumed vertices are the sibling slices (excluding this vertex), their constant
            argument vertices, any second-level slices and finally the concat node.
        """
        concat_node = None
        consumed_vertices = []
        pred = next(x for x in self._graph.predecessors(self) if x.op not in CONST_OPS)
        successors = list(self._graph.successors(pred))

        for succ in successors:
            if succ != self:
                consumed_vertices.append(succ)
            # the second value returned by get_slices_args holds the slice's Constant predecessors
            _, vertices = succ.get_slices_args()
            consumed_vertices.extend(vertices)
            next_successors = list(self._graph.successors(succ))
            if any(x.op != "Slice" for x in next_successors):
                # a non-slice successor: its first successor is taken as the concat
                concat_node = next_successors[0]
            else:
                # second level of slices: consume them and their constants as well
                consumed_vertices.extend(next_successors)
                for next_succ in next_successors:
                    _, vertices = next_succ.get_slices_args()
                    consumed_vertices.extend(vertices)

                if concat_node is None:
                    concat_node = next(iter(self._graph.successors(next_successors[0])))

        consumed_vertices.append(concat_node)
        return DEFAULT_SPACE_TO_DEPTH_BLOCK_SIZE, consumed_vertices

    def is_flattened_global_maxpool(self):
        """Check for a ``Flatten`` (axis unset or 1) directly followed by a matching ``ReduceMax``.

        The ReduceMax must carry a single ``axes`` attribute equal to ``[1]`` and keep dims.
        The result is truthy/falsy rather than a strict bool.
        """
        if self.op != "Flatten":
            return False

        flatten_axis = self.get_axis()
        if not (flatten_axis is None or flatten_axis == 1):
            return False

        reduce_max = look_for_node(self._graph, self, [FwdChainNode(op="ReduceMax")])
        if not reduce_max:
            return False

        axes_attr = reduce_max.get_attribute_by_name("axes")
        return axes_attr and len(axes_attr) == 1 and axes_attr[0].ints == [1] and reduce_max.get_keepdims()

    def is_flattened_global_avgpool(self):
        """Detect a global average pool written as ``Reshape -> ReduceMean -> Reshape``.

        This ``Reshape`` must flatten the two trailing dims of a rank-4 input into one
        (rank-3 output, dim 1 unchanged), its only successor must be a ``ReduceMean`` whose
        first ``axes`` attribute value is -1 or 2, and that node's only successor must be a
        ``Reshape`` producing a rank-4 output with dims 2 and 3 equal to 1 and dim 1 equal
        to the original dim 1.

        A ``ReduceMean`` without a usable ``axes`` attribute (e.g. axes supplied as an
        input) does not match the pattern.
        """
        if self.op != "Reshape":
            return False
        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        output_shapes = self.get_output_shapes(convert_to_nhwc=False)
        if not input_shapes or not output_shapes:
            return False

        if not (
            len(output_shapes[0]) == 3
            and len(input_shapes[0]) == 4
            and input_shapes[0][2] * input_shapes[0][3] == output_shapes[0][2]
            and input_shapes[0][1] == output_shapes[0][1]
        ):
            return False

        succs = list(self._graph.successors(self))
        if len(succs) != 1 or succs[0].op != "ReduceMean":
            return False

        # the axes attribute may be absent or empty; treat that as "no match" instead of
        # failing with IndexError
        axes_attr = succs[0].get_attribute_by_name("axes")
        if not axes_attr or len(axes_attr[0].ints) == 0 or axes_attr[0].ints[0] not in [-1, 2]:
            return False

        succs = list(self._graph.successors(succs[0]))
        if len(succs) != 1 or succs[0].op != "Reshape":
            return False

        output_shapes = succs[0].get_output_shapes(convert_to_nhwc=False)
        return bool(
            output_shapes
            and len(output_shapes[0]) == 4
            and output_shapes[0][2] == output_shapes[0][3] == 1
            and output_shapes[0][1] == input_shapes[0][1]
        )

    def get_flattened_pooling_info(self):
        """Return ``(output_shapes, consumed_vertices)`` for a flattened global pooling.

        For ``Reshape`` the shapes come from the second Reshape of a
        ``ReduceMean -> Reshape`` chain; for ``Flatten`` they come from the following
        ``ReduceMax``. Any other op yields ``None``.
        """
        if self.op == "Reshape":
            mean_node = look_for_node(self._graph, self, [FwdChainNode(op="ReduceMean")])
            closing_reshape = look_for_node(self._graph, mean_node, [FwdChainNode(op="Reshape")])
            return closing_reshape.get_output_shapes(), [mean_node, closing_reshape]

        if self.op == "Flatten":
            max_node = look_for_node(self._graph, self, [FwdChainNode(op="ReduceMax")])
            return max_node.get_output_shapes(), [max_node]

        return None

    def is_spatial_1x1_descendant(self):
        """Return True when this vertex, or all of its non-constant ancestors, is spatial 1x1.

        A flat-to-frames ``Reshape`` stops the search with False, as does a vertex without
        any non-constant predecessor.
        """
        if self.is_spatial_1x1:
            return True

        if self.op == "Reshape" and self.is_flat_to_frames_reshape():
            return False

        data_preds = [pred for pred in self.graph.predecessors(self) if not pred.is_const()]
        if not data_preds:
            return False

        for pred in data_preds:
            if not pred.is_spatial_1x1_descendant():
                return False
        return True

    def is_multi_head_attention_dense(self):
        """Check for a ``Gemm`` whose only successor is a ``Reshape`` to a ``[_, 1, _]`` shape.

        Returns:
            tuple: ``(True, [reshape])`` on a match, ``(False, [])`` otherwise.
        """
        no_match = (False, [])
        if self.op != "Gemm":
            return no_match

        consumers = list(self.graph.successors(self))
        if len(consumers) != 1 or consumers[0].op != "Reshape":
            return no_match

        reshape = consumers[0]
        reshape_shapes = reshape.get_output_shapes(convert_to_nhwc=False)
        if not reshape_shapes:
            return no_match

        first_shape = reshape_shapes[0]
        if len(first_shape) == 3 and first_shape[1] == 1:
            return True, [reshape]
        return False, []

    def is_conv1x1_dense(self):
        """Decide whether this dense op should be handled as a 1x1 convolution.

        A multi-head-attention dense always qualifies. A matmul layer, or a dense whose
        successors are all last-layer activations or include a flat-to-frames reshape,
        never does. Otherwise the op qualifies when it descends from a spatial 1x1 vertex
        (which is then cached on ``is_spatial_1x1``), or when it does not follow a flatten
        and keeps its format and all but the last dim of its shape.
        """
        if self.op not in DENSE_OPS:
            return False

        if self.is_multi_head_attention_dense()[0]:
            return True

        consumers = list(self.graph.successors(self))
        excluded_by_consumers = consumers and (
            all(consumer.is_last_layer_activation() for consumer in consumers)
            or any(consumer.is_flat_to_frames_reshape() for consumer in consumers)
        )
        if excluded_by_consumers or self.is_matmul_layer():
            return False

        if self.is_spatial_1x1_descendant():
            self.is_spatial_1x1 = True
            return True

        producer = next(iter(self._graph.predecessors(self)))
        if producer.op == "Flatten" or (producer.op == "Reshape" and producer.is_flatten_reshape()[0]):
            return False

        input_format, output_format = self.input_format, self.output_format
        leading_in = self.get_input_shapes(convert_to_nhwc=False)[0][:-1]
        leading_out = self.get_output_shapes(convert_to_nhwc=False)[0][:-1]
        return bool(input_format == output_format and leading_in == leading_out)

    def is_last_layer_activation(self):
        """Return True for an activation/logits op that has no successors."""
        if self.op not in ACTIVATION_OPS + LOGITS_OPS:
            return False

        return not list(self.graph.successors(self))

    def is_conv1x1_matmul(self):
        """Decide whether this ``MatMul`` should be handled as a 1x1 convolution.

        One of the following must hold, and the vertex must not be a matmul layer: the
        first output is rank 3 (or rank 4 with dim 1 equal to 1), the output format is
        ``[WIDTH, CHANNELS]``, both the preceding reshape and the other reshape feeding the
        following concat are null operations, or the matmul is over groups.
        """
        if self.op != "MatMul":
            return False

        shapes = self.get_output_shapes(convert_to_nhwc=False)
        if not shapes:
            return False

        first_shape = shapes[0]
        rank = len(first_shape)
        rank3_like_output = rank == 3 or (rank == 4 and first_shape[1] == 1)

        reshape_before = look_for_node(self.graph, self, [BwdChainNode(op="Reshape")])
        reshape_before_is_null = reshape_before is not None and reshape_before.is_null_operation()
        sibling_reshape = look_for_node(self.graph, self, [FwdChainNode("Concat"), BwdChainNode(op="Reshape")])
        sibling_reshape_is_null = sibling_reshape is not None and sibling_reshape.is_null_operation()

        is_conv1x1 = (
            rank3_like_output
            or self.output_format == [Dims.WIDTH, Dims.CHANNELS]  # MHA condition
            or (reshape_before_is_null and sibling_reshape_is_null)
            or self.is_matmul_over_groups()
        )

        return is_conv1x1 and not self.is_matmul_layer()

    def should_transpose_kernel(self):
        """Return True when this ``Gemm`` has ``transB`` set to 1.

        Gemm operator has 2 attributes - transA, and transB that toggle transpose for each
        input. In our case, input A is the data tensor, and input B is the weights' matrix.
        ``transB`` is optional and defaults to 0, so a missing attribute means False.
        """
        if self.op != "Gemm":
            return False

        trans_b = self.get_attribute_by_name("transB")
        # an omitted attribute previously failed with IndexError on the empty lookup
        return bool(trans_b) and trans_b[0].i == 1

    def is_dilated_conv(self):
        """Detect a dilated conv expressed through SpaceToDepth/DepthToSpace.

        This vertex must be a ``Transpose`` with perm ``[1, 0, 2, 3]`` followed by the chain
        SpaceToDepth, Transpose, Conv, Transpose, DepthToSpace, Transpose, Slice, Transpose.
        Both block sizes must match, the chain's transposes must use the expected perms and
        the slice must be a full-range, unit-step slice over axes 1 and 2.
        """
        if self.op != "Transpose" or self.get_transpose_perm() != [1, 0, 2, 3]:
            return False

        chain_ops = (
            "SpaceToDepth",
            "Transpose",
            "Conv",
            "Transpose",
            "DepthToSpace",
            "Transpose",
            "Slice",
            "Transpose",
        )
        nodes = get_all_nodes_in_chain(self._graph, self, [FwdChainNode(op=op) for op in chain_ops])
        if not nodes:
            return False

        if nodes[0].get_d2s_block_size() != nodes[4].get_d2s_block_size():
            return False

        expected_perms = [[1, 0, 2, 3], [1, 0, 2, 3], [1, 2, 3, 0], [0, 3, 1, 2]]
        chain_perms = [node.get_transpose_perm() for node in nodes if node.op == "Transpose"]
        if chain_perms != expected_perms:
            return False

        max_end = 2**31 - 1
        args = nodes[-2].get_slices_args()[0]
        return (
            list(args["starts"]) == [0, 0]
            and list(args["ends"]) == [max_end, max_end]
            and list(args["axes"]) == [1, 2]
            and list(args["steps"]) == [1, 1]
        )

    def get_dilated_conv_info(self):
        """Return ``(kernel, dilations, strides, nodes)`` for the dilated conv chain.

        The dilation rate is the block size of the chain's first node, applied to both
        middle entries of ``dilations``; kernel and strides come from the chain's Conv.
        """
        chain_ops = (
            "SpaceToDepth",
            "Transpose",
            "Conv",
            "Transpose",
            "DepthToSpace",
            "Transpose",
            "Slice",
            "Transpose",
        )
        nodes = get_all_nodes_in_chain(self._graph, self, [FwdChainNode(op=op) for op in chain_ops])

        rate = nodes[0].get_d2s_block_size()
        conv = nodes[2]
        kernel, _ = conv.get_kernel()

        return kernel, [1, rate, rate, 1], conv.get_strides(), nodes

    def is_height_to_features_reshape(self):
        """Detect a rank-4 reshape that folds dim 2 into dim 1 ahead of a 1x1 conv.

        The output must be ``[in[0], in[1] * in[2], 1, in[3]]`` and the directly following
        ``Conv`` must have a ``[1, 1]`` kernel shape and ``[1, 1, 1, 1]`` strides.
        """
        if self.op != "Reshape":
            return False

        conv = look_for_node(self._graph, self, [FwdChainNode(op="Conv")])
        if conv is None:
            return False

        in_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        out_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
        both_rank4 = len(in_shape) == len(out_shape) == 4
        folds_height = both_rank4 and list(out_shape) == [in_shape[0], in_shape[1] * in_shape[2], 1, in_shape[3]]

        conv_is_pointwise = conv.get_kernel_shape() == [1, 1] and conv.get_strides() == [1, 1, 1, 1]
        return folds_height and conv_is_pointwise

    def get_height_to_features_conv_info(self):
        """Build the conv parameters that replace a height-to-features reshape + conv.

        The following Conv's kernel is reshaped to ``[f_out, f_in, height, 1]`` and then
        transposed with ``[2, 3, 1, 0]``, where ``height`` is dim 2 of this vertex's input.

        Returns:
            tuple: ``(kernel, bias, strides, [conv])`` with ``strides = [1, height, 1, 1]``.

        Raises:
            UnsupportedConvLayerError: if ``kernel.shape[1]`` is not divisible by ``height``.
        """
        conv = look_and_validate(self._graph, self, [FwdChainNode(op="Conv")])
        kernel, _ = conv.get_kernel()
        bias, _ = conv.get_bias()

        height = self.get_input_shapes(convert_to_nhwc=False)[0][2]

        if kernel.shape[1] % height != 0:
            raise UnsupportedConvLayerError(
                f"Tried to convert reshape vertex {self.name} which "
                f"transposes height<->features, but the number of "
                f"output features {kernel.shape} isn't divisible by "
                f"its input height {height}.",
            )

        f_out = kernel.shape[0]
        f_in = int(kernel.shape[1] / height)
        folded = np.reshape(kernel, [f_out, f_in, height, 1])
        folded = np.transpose(folded, [2, 3, 1, 0])  # [k_w, k_h, f_in, f_out] (onnx repr)

        return folded, bias, [1, height, 1, 1], [conv]

    def get_dynamic_kernel_shape(self):
        """Return the second input's shape reordered as dims ``[2, 3, 1, 0]``.

        ``None`` is returned unless the vertex has exactly two input shapes.
        """
        shapes = self.get_input_shapes(convert_to_nhwc=False)
        if len(shapes) != 2:
            return None

        kernel_shape = shapes[1]
        return [kernel_shape[dim] for dim in (2, 3, 1, 0)]

    def is_null_reshape(self):
        """
        Decide whether this Reshape/Flatten vertex can be treated as a no-op.

        A sequence of checks is tried in order; the first one that matches wins.

        Returns:
            tuple: (is_null, vertices) - a bool verdict and a list of additional vertices returned by the matching
            check (empty for most checks, and always empty when the verdict is False).

        """
        if self.op not in ["Reshape", "Flatten"]:
            return False, []

        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        if not input_shapes:
            return False, []

        if self.is_features_to_stack_with_flat_height_reshape(self.input_format):
            return True, []

        if self.is_features_to_stack_with_flat_groups_reshape(self.input_format):
            return True, []

        # checks if the vertex is null op as part of a chain of nodes
        chain = [FwdChainNode("Transpose"), FwdChainNode("Reshape"), FwdChainNode("Transpose")]
        nodes = get_all_nodes_in_chain(self._graph, self, chain)
        if nodes:
            # ensures all nodes have only one output
            does_have_one_output = all(len(node.get_output_shapes()) == 1 for node in nodes)
            if does_have_one_output and nodes[-1].get_output_shapes() == self.get_input_shapes():
                return True, nodes

        # shorter variant: Transpose -> Reshape that restores this vertex's input shapes
        nodes = get_all_nodes_in_chain(self._graph, self, [FwdChainNode(op="Transpose"), FwdChainNode(op="Reshape")])
        if nodes:
            transpose_with_1 = 1 in self.get_output_shapes(convert_to_nhwc=False)[0][1:3]
            rank3 = len(self.get_input_shapes(convert_to_nhwc=False)[0]) == 3
            allowed_perm = nodes[0].get_transpose_perm() == [0, 2, 1, 3]
            does_have_one_output = all(len(node.get_output_shapes()) == 1 for node in nodes)
            same_shape = nodes[-1].get_output_shapes() == self.get_input_shapes()
            if transpose_with_1 and rank3 and allowed_perm and does_have_one_output and same_shape:
                return True, nodes

        # single Transpose whose output, minus its leading axis, equals the rank-3 input shape
        nodes = get_all_nodes_in_chain(self._graph, self, [FwdChainNode(op="Transpose")])
        if nodes:
            input_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
            transpose = nodes[-1].get_output_shapes(convert_to_nhwc=False)[0]
            if len(input_shape) == 3 and transpose[1:] == input_shape:
                return True, nodes

        chain = [FwdChainNode(x) for x in ["Split", "Transpose", "MatMul", "Transpose", "Reshape"]]
        nodes = get_all_nodes_in_chain(self._graph, self, chain)
        if nodes:
            # checks if split has 3 transpose successors - edge case of transformer
            succs = list(self._graph.successors(nodes[0]))
            if all(succ.op == "Transpose" for succ in succs):
                return True, []

        # the same op sequence, looked up backwards from this vertex
        chain = [BwdChainNode(x) for x in ["Transpose", "MatMul", "Transpose", "Split", "Reshape"]]
        nodes = get_all_nodes_in_chain(self._graph, self, chain)
        if nodes:
            return True, []

        # checks if the vertex is null op by itself
        output_shape, consumed_vertices = self.get_reshape_shapes()
        input_shape = input_shapes[0]
        if input_shape == output_shape:
            return True, consumed_vertices

        # Check if reshape is before conv with dynamic kernel
        succs = list(self._graph.successors(self))
        if (
            len(succs) == 1
            and succs[0].op == "Conv"
            and len(succs[0]._info.input) == 2
            and succs[0]._info.input[1] == self._info.output[0]
            and len(output_shape) == len(input_shape) == 4
            and input_shape[0] == output_shape[1]
            and input_shape[1] == output_shape[0]
            and input_shape[2:] == output_shape[2:]
        ):
            return True, consumed_vertices

        # NOTE(review): raises StopIteration when the vertex has no predecessors in the graph
        pred = next(iter(self.graph.predecessors(self)))
        is_f_to_h = self._is_features_to_height_reshape(pred.output_format)
        is_flatten_h_to_f = self._is_flatten_height_to_features_reshape(pred.output_format)
        if self.is_features_to_groups_reshape(pred.output_format)[0] and not (is_f_to_h or is_flatten_h_to_f):
            return True, []

        if self.is_groups_to_features_reshape(pred.output_format)[0]:
            return True, []

        # squeeze on any dimension when no input format is available (maybe after input layer)
        if pred.output_format is None and pred.is_null_unsqueeze():
            batch_index = pred.get_axes_information()[0]
            expected_output_shape = input_shape.copy()
            expected_output_shape.pop(batch_index)
            if expected_output_shape == output_shape:
                return True, []

        # squeeze on batch reshape - input format is available
        if self.is_squeeze_on_batch_reshape():
            transpose_output_shape = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])
            # the conditional expression applies to the list only: (True, [node]) or (True, [])
            return True, [transpose_output_shape] if transpose_output_shape else []

        if self.is_unsqueeze_on_batch_reshape() and not self.is_attention_windows_to_input_reshape():
            return True, []

        if self.is_reshape_transpose_expand_height_dim():
            return True, []

        if self.is_inner_product_matmul():
            return True, []

        return False, []

    def is_squeeze_on_batch_reshape(self):
        """
        Check whether this Reshape only drops the axis that the predecessor's output format marks as batch.

        False is returned when shapes are missing, or when the predecessor has no output format or no batch
        dimension in it.
        """
        if self.op != "Reshape":
            return False

        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        output_shapes = self.get_output_shapes(convert_to_nhwc=False)
        if not (input_shapes and output_shapes):
            return False

        pred = next(iter(self.graph.predecessors(self)))
        pred_format = pred.output_format
        if pred_format is None or Dims.BATCH not in pred_format:
            return False

        # remove the batch axis from a copy of the input shape and compare with the actual output shape
        squeezed_shape = input_shapes[0].copy()
        del squeezed_shape[pred_format.index(Dims.BATCH)]
        return squeezed_shape == output_shapes[0]

    def is_unsqueeze_on_batch_reshape(self):
        """
        Check whether this Reshape only inserts a single size-1 axis into a tensor that has no batch dimension.

        The predecessor must have an output format that does not contain a batch dimension, and the output
        rank must be exactly one more than the input rank.
        """
        if self.op != "Reshape":
            return False

        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        output_shapes = self.get_output_shapes(convert_to_nhwc=False)
        if not (input_shapes and output_shapes):
            return False

        pred_format = next(iter(self.graph.predecessors(self))).output_format
        if pred_format is None or Dims.BATCH in pred_format:
            return False

        if len(output_shapes[0]) != len(input_shapes[0]) + 1:
            return False

        return self._find_unsqueeze_axis() is not None

    def _find_unsqueeze_axis(self):
        input_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        output_shape = self.get_output_shapes(convert_to_nhwc=False)[0]

        for i in range(len(output_shape)):
            unsqueezed_on_i_shape = input_shape.copy()
            unsqueezed_on_i_shape.insert(i, 1)
            if unsqueezed_on_i_shape == output_shape:
                return i
        return None

    def get_flatten_chain(self, fwd=False):
        """
        Collect the run of consecutive Flatten vertices adjacent to this vertex.

        Args:
            fwd: look up with FwdChainNode when True, with BwdChainNode otherwise.

        Returns:
            list: the Flatten vertices found, in the order they were reached.

        """
        direction = FwdChainNode if fwd else BwdChainNode
        flattens = []
        cursor = self
        while True:
            cursor = look_for_node(self._graph, cursor, [direction(op="Flatten")])
            if not cursor:
                break
            flattens.append(cursor)

        return flattens

    def is_square(self):
        """Check whether the vertex squares its input: a power op with exponent 2.0, or a square mul."""
        if self.op in POW_OPS:
            exponent = self.get_power()[0]
            return exponent == 2.0

        return self.is_square_mul() if self.op in MUL_OPS else False

    def is_square_mul(self):
        """Check whether the vertex has exactly two inputs and both of them are equal."""
        inputs = self.input
        if len(inputs) != 2:
            return False
        return inputs[0] == inputs[1]

    def get_power(self):
        """
        Read the exponent of a Pow vertex.

        The exponent is taken from a Constant/Cast/Identity predecessor when one is found, otherwise from the
        first value registered for this vertex in the graph, and defaults to 0.0.

        Returns:
            tuple: (power, consumed_vertices) - consumed_vertices holds the predecessor the exponent was read from,
            if any.

        Raises:
            UnexpectedNodeError: if the vertex op isn't Pow.

        """
        if self.op != "Pow":
            raise UnexpectedNodeError(f"Power operator is exported as Pow in ONNX, found {self.op} instead.")

        var_initializers = self._graph.values_by_vertex_name[self.name]
        exponent_chains = [[BwdChainNode(op=op)] for op in ("Constant", "Cast", "Identity")]
        y_const = get_node_from_possible_chains(self._graph, self, exponent_chains)

        if y_const:
            return y_const.parse_raw_data().tolist(), [y_const]

        if var_initializers:
            first_value = next(iter(var_initializers.values()))
            return float(first_value.tolist()), []

        return 0.0, []

    def is_ew_op_with_const_input(self):
        """
        Check whether this elementwise vertex has a constant operand that must be fed as a separate input.

        The constant values are taken from a Constant/Initializer predecessor, from the graph's inferred
        "<input>_value" entries, or from the values registered for this vertex. Their shape is then compared
        against the vertex output shape (and output format, when available).

        Returns:
            bool: True when the constant's shape matches one of the supported spatial/transformer patterns.

        """
        # If we find an Add node, that has an initializer (or const) which is of the size of the entire tensor,
        # we cannot create a bias add layer or a normalization layer from it. Therefore, our only option is to
        # create a 'synthetic' input, that will need to be fed manually by the value of the initializer (or the
        # constant)
        if self.op not in EW_OPS:
            return False

        output_shapes = self.get_output_shapes(convert_to_nhwc=False)
        if not output_shapes:
            return False

        output_shape = output_shapes[0]
        # outputs of rank 2 or lower are never handled here
        if len(output_shape) <= 2:
            return False

        const_node = get_node_from_possible_chains(
            self._graph, self, [[BwdChainNode(op="Constant")], [BwdChainNode(op="Initializer")]]
        )
        variable_val = self._graph.values_by_vertex_name.get(self.name)
        # the "<second input>_value" lookup is only the fallback for a missing "<first input>_value" entry
        inference_value = self._graph.output_shapes.get(
            self._info.input[0] + "_value",
            self._graph.output_shapes.get(self._info.input[1] + "_value"),
        )

        if const_node is None and not variable_val and inference_value is None:
            return False

        # source priority: const predecessor, then inferred value, then the vertex's first registered value
        if const_node:
            values = const_node.parse_raw_data()
        elif inference_value is not None:
            values = inference_value
        else:
            values = next(iter(variable_val.values()))
        values_shape = list(np.shape(values))

        if self.output_format:
            if Dims.CHANNELS in self.output_format:
                channels_shape = output_shape[self.output_format.index(Dims.CHANNELS)]
            else:
                channels_shape = 1

            # Add an extra batch dimension to values shape if needed
            if len(values_shape) == len(self.output_format) - 1 and self.output_format[0] == Dims.BATCH:
                values_shape = [1, *values_shape]

            # rank-1 values that are neither a scalar nor per-channel: look for a matching height/width axis
            if len(values_shape) == 1 and values_shape[0] not in [1, channels_shape]:
                for dim, dim_shape in zip(self.output_format, output_shape):
                    if dim_shape == values_shape[0] and dim in [Dims.HEIGHT, Dims.WIDTH]:
                        return True
            elif len(self.output_format) == len(values_shape):
                # same rank as the format: any non-trivial height/width extent qualifies, otherwise give up
                for dim, dim_shape in zip(self.output_format, values_shape):
                    if dim in [Dims.HEIGHT, Dims.WIDTH] and dim_shape != 1:
                        return True
                return False

        if len(values_shape) < 2:
            return False

        # checks rank spatial condition
        value_rank = len(values_shape)
        general_spatial_cond = (value_rank in [2, 3] and values_shape[:] == output_shape[-value_rank:]) or (
            value_rank == 4
            and (
                (values_shape[1:] == output_shape[1:])
                or (values_shape[1] == 1 and values_shape[2:] == output_shape[2:])
            )
        )
        if general_spatial_cond:
            return True

        transformer_encoder_values_cond = len(values_shape) > 1 and values_shape[1:] == output_shape[1:]
        transformer_encoder_values_rank2_cond = len(values_shape) == 2 and values_shape == output_shape[1:]
        transformer_decoder_values_cond = len(values_shape) == 3 and values_shape[2:] == output_shape[1:]
        # NOTE(review): raises StopIteration when the vertex has no predecessors in the graph
        pred = next(iter(self._graph.predecessors(self)))
        multihead_attn_batch_first_cond = (
            (pred.output_format and pred.output_format[0] != Dims.BATCH)
            and len(values_shape) == 3
            and values_shape[0] == output_shape[1]
        )
        return (
            transformer_encoder_values_cond
            or transformer_encoder_values_rank2_cond
            or transformer_decoder_values_cond
            or multihead_attn_batch_first_cond
        )

    def get_const_input_values(self):
        """
        Return the constant operand values of this vertex, or None when there are none.

        Sources are tried in this order: a Constant predecessor, the first value registered for this vertex in
        the graph, and finally the graph's inferred "<input>_value" entry of the first (or else second) input.
        """
        if self.op not in [*EW_OPS, "Expand", "Concat", *SCATTER_ND_OPS]:
            return None

        const_node = look_for_node(self._graph, self, [BwdChainNode(op="Constant")])
        params = self._graph.values_by_vertex_name.get(self.name)

        # all three sources are resolved up front; only the choice between them is lazy
        inferred_values = self._graph.output_shapes
        second_input_value = inferred_values.get(self._info.input[1] + "_value")
        inference_value = inferred_values.get(self._info.input[0] + "_value", second_input_value)

        if const_node is not None:
            return const_node.parse_raw_data()

        if params:
            return next(iter(params.values()))

        return inference_value

    def is_concat_with_new_input(self):
        """
        Check whether this Concat has an operand that doesn't come from a regular predecessor.

        That is the case when it has exactly two predecessors and one is an Expand, when it has several
        predecessors and one is a Constant, or when the graph holds values registered under this vertex's name.
        """
        if self.op != "Concat":
            return False

        pred_ops = [pred.op for pred in self._graph.predecessors(self)]
        has_expand_pair = len(pred_ops) == 2 and "Expand" in pred_ops
        has_const_pred = len(pred_ops) > 1 and "Constant" in pred_ops
        has_vertex_params = bool(self._graph.values_by_vertex_name.get(self.name))

        return has_expand_pair or has_const_pred or has_vertex_params

    def is_spatial_concat(self, rank4_dim):
        """
        Check whether this Concat joins its inputs along the axis that the output format labels as rank4_dim.

        When the vertex is a Concat without an output format, the (falsy) output format itself is returned.
        """
        if self.op != "Concat":
            return False

        concat_axis = self.get_axis()
        output_format = self.output_format
        if not output_format:
            return output_format
        return output_format[concat_axis] == rank4_dim

    def is_spatial_h_concat(self):
        """Check whether this vertex is a Concat along the axis its output format labels as height."""
        # rank 3 can't be spatial h
        height_dim = Dims.HEIGHT
        return self.is_spatial_concat(height_dim)

    def is_spatial_w_concat(self):
        """Check whether this vertex is a Concat along the axis its output format labels as width."""
        width_dim = Dims.WIDTH
        return self.is_spatial_concat(width_dim)

    def get_new_concat_input(self):
        """
        Return the single Expand predecessor of this vertex together with all of its shape nodes.

        Returns:
            tuple: (expand_vertex, shape_nodes) when exactly one predecessor is an Expand, otherwise (False, []).

        """
        expand_preds = [pred for pred in self._graph.predecessors(self) if pred.op == "Expand"]
        if len(expand_preds) != 1:
            return False, []

        (expand,) = expand_preds
        return expand, self.get_all_shape_nodes()

    def get_const_layer_input_order(self):
        """
        Build the ordered list of input names for a vertex that mixes regular and constant inputs.

        Entries for regular predecessors look like "<pred name>:<suffix>" and entries for constant inputs look
        like "<vertex name>_input:<suffix>". With exactly two inputs the suffix is the input tensor name;
        otherwise, and for inputs found among the values registered for this vertex, it is the input's position.
        Positions that could not be resolved stay None.

        Returns:
            list: one entry per input of the vertex, in the vertex's own input order.

        """
        input_indices = list(self._info.input)
        num_inputs = len(input_indices)
        preds = list(self._graph.predecessors(self))
        new_order = [None] * num_inputs

        if num_inputs == 2:
            non_const_pred = next(x for x in preds if not x.is_const())

            for out in non_const_pred.output:
                if out in input_indices:
                    non_const_index = input_indices.index(out)
                    # with two inputs, the constant one occupies the other slot
                    new_input_index = 1 - non_const_index
                    new_order[new_input_index] = f"{self.name}_input:{input_indices[new_input_index]}"
                    new_order[non_const_index] = f"{non_const_pred.name}:{input_indices[non_const_index]}"
                    break
        else:
            for pred in preds:
                key = pred.name if pred.op in INPUT_OPS else pred.output[0]
                index = input_indices.index(key)
                new_input = f"{self.name}_input:{index}" if pred.is_const() else f"{pred.name}:{index}"
                new_order[index] = new_input

        if len(preds) < num_inputs:
            # some inputs have no producing vertex; the vertex may have no registered values at all, in which
            # case there is nothing to add (iterating None would raise TypeError)
            params = self._graph.values_by_vertex_name.get(self.name) or {}
            for key in params:
                if key in input_indices:
                    index = input_indices.index(key)
                    new_order[index] = f"{self.name}_input:{index}"

        return new_order

    def get_concat_new_input(self):
        """
        Name the two inputs of a concat that has a single predecessor.

        Returns:
            list: ["<vertex name>_input:<first input>", "<predecessor name>:<second input>"].

        """
        pred = next(iter(self._graph.predecessors(self)))
        new_input_name = self._info.input[0]
        pred_input_name = self._info.input[1]
        return [f"{self.name}_input:{new_input_name}", f"{pred.name}:{pred_input_name}"]

    def is_shape_expand_norm(self):
        """
        Check whether this Mul/Add is fed by an Expand (driven by a Shape) of a per-axis-1 initializer.

        The Expand's registered initializer must have size 1 on axes 0, 2 and 3.

        Returns:
            tuple: (True, [shape_node, expand_node]) on a match, otherwise (False, []).

        """
        if self.op not in ["Mul", "Add"]:
            return False, []

        chain = [BwdChainNode(op="Expand"), BwdChainNode(op="Shape")]
        found = get_all_nodes_in_chain(self._graph, self, chain)
        if not found:
            return False, []

        expand_node, shape_node = found
        expand_value = self._graph.values_by_vertex_name.get(expand_node.name)
        if not expand_value:
            return False, []

        init_shape = next(iter(expand_value.values())).shape
        if init_shape[0] == init_shape[2] == init_shape[3] == 1:
            return True, [shape_node, expand_node]

        return False, []

    def get_input_layer_shapes(self):
        """
        Resolve the output format of an input vertex and return its shape with a dynamic (-1) leading axis.

        The shape comes from the vertex's own output shapes, or else from the graph's tensor shapes. The format is
        looked up in the graph's net_input_format (by this vertex's name or its first successor's name) and falls
        back to DEFAULT_FORMAT_BY_RANK; it is stored on self.output_format.

        Returns:
            list: a single-element list holding [-1, *converted_shape[1:]].

        Raises:
            UnsupportedInputShapesError: if a rank-4 shape has a value below 1 on either of its last two axes.
            UnsupportedInputFormatError: if the resolved format's length differs from the shape's rank.

        """
        graph = self._graph
        onnx_shapes = self.get_output_shapes(convert_to_nhwc=False)
        if onnx_shapes:
            onnx_shape = onnx_shapes[0]
        else:
            onnx_shape = graph.tensor_shapes_by_vertex_name[self.name]
        rank = len(onnx_shape)

        has_dynamic_spatial_dim = rank == 4 and (onnx_shape[2] < 1 or onnx_shape[3] < 1)
        if has_dynamic_spatial_dim:
            # net_input_shapes should be in NCHW format
            shape_example = str([*onnx_shape[:2], 224, 224])
            raise UnsupportedInputShapesError(
                f"Unsupported dynamic shape found on input node {self.name}, consider using start_node_names + "
                f"net_input_shapes arguments when calling translator, e.g. "
                f"runner.translate_onnx_model(..., start_node_names=[{self.name}], net_input_shapes={shape_example})",
                shape_example,
            )

        # fallback lookup name for cases where net_input_format matches start_node_names further from the model input
        if self.name in graph.net_input_format:
            lookup_name = self.name
        else:
            lookup_name = next(iter(graph.successors(self))).name

        default_format = DEFAULT_FORMAT_BY_RANK.get(rank)
        self.output_format = graph.net_input_format.get(lookup_name, default_format)
        # NOTE(review): len() fails with TypeError if no format was resolved (None) - confirm whether
        # DEFAULT_FORMAT_BY_RANK covers every rank that can reach this point
        if len(self.output_format) != rank:
            msg = f"Input format for {self.name}: {self.output_format} doesn't match its rank ({rank})."
            raise UnsupportedInputFormatError(msg, recommendation=DEFAULT_FORMAT_BY_RANK.get(rank))

        if self.output_format:
            hn_shape = self.convert_nchw_to_nhwc(onnx_shape, self.output_format)
        else:
            hn_shape = onnx_shape
        return [[-1, *hn_shape[1:]]]

    def is_pre_layer_op(self):
        """Check whether this vertex's op is listed in PRE_LAYER_OPS."""
        op_type = self.op
        return op_type in PRE_LAYER_OPS

    def is_inner_product_matmul(self):
        """
        Check whether this Reshape/MatMul takes part in a Reshape -> Transpose -> MatMul pattern in which both
        chain nodes are also direct neighbors of this vertex.

        For a Reshape the chain and neighbors are looked up forwards (successors); for a MatMul they are
        looked up backwards (predecessors). The vertex must have exactly two such neighbors.
        """
        if self.op == "Reshape":
            direction, far_op, neighbors_of = FwdChainNode, "MatMul", self._graph.successors
        elif self.op == "MatMul":
            direction, far_op, neighbors_of = BwdChainNode, "Reshape", self._graph.predecessors
        else:
            return False

        chains = [[direction(op="Transpose"), direction(op=far_op)]]
        nodes = get_all_nodes_from_possible_chains(self._graph, self, chains)
        neighbors = list(neighbors_of(self))

        if nodes is None or len(nodes) != 2 or len(neighbors) != 2:
            return False
        return all(node in neighbors for node in nodes)

    def is_reshape_before_einsum(self):
        """Check whether this Reshape's only successor is an Einsum that isn't a group-conv einsum."""
        if self.op != "Reshape":
            return False

        succs = list(self._graph.successors(self))
        if len(succs) != 1:
            return False

        (succ,) = succs
        return succ.op == "Einsum" and not succ.is_group_conv_einsum()

    def get_reshape_before_einsum_info(self):
        """
        Return the number of groups implied by this reshape's output.

        That is the second entry of the output shape when the vertex has a single, rank-5 output shape, and 1
        otherwise.
        """
        output_shapes = self.get_output_shapes()
        if len(output_shapes) != 1:
            return 1

        shape = output_shapes[0]
        return shape[1] if len(shape) == 5 else 1

    def is_einsum_conv1x1(self):
        """
        Check whether this Einsum has exactly one predecessor whose op isn't in CONST_OPS.

        Raises:
            UnsupportedEinsumLayerError: if the vertex's equation isn't in EINSUM_SUPPORTED_EQUATIONS.

        """
        if self.op != "Einsum":
            return False

        equation = self.get_attribute_by_name("equation")[0].s.decode("utf-8")
        if equation not in EINSUM_SUPPORTED_EQUATIONS:
            raise UnsupportedEinsumLayerError(
                f"Layer {self.name} has unsupported equation: {equation}. "
                f"Currently supporting: {EINSUM_SUPPORTED_EQUATIONS}",
            )

        non_const_preds = sum(1 for pred in self._graph.predecessors(self) if pred.op not in CONST_OPS)
        return non_const_preds == 1

    def is_group_conv_einsum(self, equation=None):
        """
        Check whether the einsum equation is "bmchw,bnmc->bmhwn".

        Args:
            equation: the equation to test; when falsy, it is read from the vertex's "equation" attribute.

        """
        if not equation:
            equation = self.get_attribute_by_name("equation")[0].s.decode("utf-8")
        return equation == "bmchw,bnmc->bmhwn"

    def get_einsum_info(self):
        """Return (kernel, equation): the einsum kernel and the decoded "equation" attribute of this vertex."""
        raw_equation = self.get_attribute_by_name("equation")[0].s
        equation = raw_equation.decode("utf-8")
        return self._get_einsum_kernel(), equation

    def _get_einsum_kernel(self):
        """
        Fetch the einsum kernel values, or None when they can't be found.

        The values registered for this vertex are searched first, using the input positions of every
        EINSUM_INPUT_ORDER entry after the first (the last match wins). If nothing is found, a Constant
        predecessor is used instead.
        """
        initializers = self._graph.values_by_vertex_name[self.name]

        kernel = None
        for param in EINSUM_INPUT_ORDER[1:]:
            tensor_name = self._info.input[EINSUM_INPUT_ORDER.index(param)]
            if tensor_name in initializers:
                # Try to get the values from variable initializer inputs
                kernel = initializers[tensor_name]

        if kernel is not None:
            return kernel

        const = look_for_node(self._graph, self, [BwdChainNode("Constant")])
        return None if const is None else const.parse_raw_data()

    def is_reduce_max_after_group_conv_einsum(self):
        """
        Check whether this ReduceMax directly follows a group-conv Einsum.

        When no Einsum predecessor is found, the falsy lookup result itself is returned.
        """
        if self.op != "ReduceMax":
            return False

        einsum = look_for_node(self._graph, self, [BwdChainNode("Einsum")])
        if not einsum:
            return einsum
        return einsum.is_group_conv_einsum()

    def is_null_transpose(self):
        """
        Check whether this Transpose can be ignored.

        The transpose is null when it was already recorded in the graph's null transposes, when it matches one of
        the known edge cases below, when the null Reshape right before it lists it among its returned vertices,
        or when it is first in a chain of single successors that ends with a transpose undoing it (in that
        case we ignore both transposes).

        Returns:
            bool: True if the transpose should be skipped.

        """
        if self.op != "Transpose":
            return False

        if self in self._graph.null_transposes:
            return True

        # these patterns are never skipped, even if a later check would have matched
        if self.is_rnn_sequence() or self.is_f_to_w_transpose_reshape() or self.is_dilated_conv():
            return False

        # check edge case for transpose after reshape and split in context of attention
        if self.is_null_transpose_attention_block():
            return True

        # check edge case for transpose that moves the batch dimension in multihead-attention ops
        if self.is_batch_moving_transpose() and not self.is_space_to_depth_transpose_reshape()[0]:
            return True

        # check edge case for torch resize-nearest layer chain
        if self.is_torch_tile_resize_nearest() or self.is_null_transpose_near_torch_tile():
            return True

        if (
            self.is_transpose_before_matmul()
            or self.is_transpose_after_spatial_flatten()
            or self.is_transpose_1d_already_flattened()
            or self.is_transpose_before_spatial_flatten()
            or self.is_nhwc_to_nchw_transpose()
            or self.is_nchw_to_nhwc_transpose()
            or self.is_transpose_after_features_to_groups()
            or self.is_transpose_before_groups_to_features()
            or self.is_heads_width_transpose()
            or self.is_transpose_adjacent_to_layer_norm()
            or self.is_transpose_connected_to_softmax()
            or self.is_feature_projection_transpose()
            or self.is_transpose_after_groups_to_spatial_flatten()
            or self.is_transpose_after_spatial_flatten_to_groups()
            or self.is_transpose_before_spatial_flatten_and_groups_to_features()
            or self.is_transpose_before_conv3d()
        ):
            return True

        reshape = look_for_node(self._graph, self, [BwdChainNode(op="Reshape")])
        if reshape:
            is_null_reshape, null_reshape_nodes = reshape.is_null_reshape()
            # is_null_reshape returns a list of vertices, so this has to be a membership test; an identity
            # comparison between this vertex and the list could never succeed
            if is_null_reshape and self in null_reshape_nodes:
                return True

        # If the spatial unflatten block is transpose -> BN -> reshape we skip the transpose, build BN and build
        # spatial unflatten from reshape
        chain = [FwdChainNode("BatchNormalization"), FwdChainNode("Reshape")]
        if self.is_spatial_unflatten() and look_for_node(self._graph, self, chain) is not None:
            return True

        # otherwise every successor branch must lead to a transpose that cancels this one
        succs = list(self._graph.successors(self))
        return len(succs) > 0 and all(self._find_null_transpose_chain([succ]) for succ in succs)

    def _find_null_transpose_chain(self, curr_succs):
        """
        Walk forward from this transpose looking for a later transpose that cancels it.

        Args:
            curr_succs: list of successors to start from; only its first entry is followed.

        Returns:
            bool: True when a transpose is reached whose permutation composes with this one's to the identity.
            In that case both transposes are appended to the graph's null transposes list.

        """
        self_perm = self.get_transpose_perm()
        cur_vertex = self
        # Iterate over successors with one pred and one succ, excluding a closing transpose with multiple succs
        while curr_succs:
            if cur_vertex != self:
                if cur_vertex.op != "Transpose":
                    if cur_vertex.is_conv1x1_matmul() and self.is_width_features_transpose():
                        return False
                    cur_vertex = curr_succs[0]
                else:
                    # composing the permutations should get us back to the original order
                    curr_perm = cur_vertex.get_transpose_perm()
                    if len(curr_perm) == len(self_perm) and [curr_perm[i] for i in self_perm] == sorted(curr_perm):
                        self._graph._null_transposes.extend([self, cur_vertex])
                        return True
                    return False
            else:
                # first iteration: step from this transpose onto its successor
                cur_vertex = curr_succs[0]

            curr_succs = list(self._graph.successors(cur_vertex))
            # NOTE(review): `and` binds tighter than `or`, so this reads
            # `preds != 1 or (succs != 1 and self.op != "Transpose")`. It tests self.op rather than
            # cur_vertex.op; the caller visible in this file (is_null_transpose) only gets here when self.op is
            # "Transpose", which makes the successor-count term ineffective. The comment above the loop suggests
            # cur_vertex.op may have been intended - confirm before changing.
            if len(list(self._graph.predecessors(cur_vertex))) != 1 or len(curr_succs) != 1 and self.op != "Transpose":
                return False

        return False

    def is_feature_projection_transpose(self):
        """
        Check whether this [0, 2, 1] width<->features transpose opens a known conv/slice/mul block that ends
        with another width<->features transpose.

        Two forward op sequences are looked up; at least one must be found, and every transpose found must be a
        width<->features transpose as well.
        """
        if self.get_transpose_perm() != [0, 2, 1]:
            return False

        if not self.is_width_features_transpose():
            return False

        # NOTE(review): the second sequence is the first one's inner ops in reverse, and the original code named
        # its result "bwd_transpose", yet both are looked up with FwdChainNode - confirm whether a backward
        # lookup was intended.
        op_sequences = (
            ("Conv", "Slice", "Mul", "Mul", "Transpose"),
            ("Mul", "Mul", "Slice", "Conv", "Transpose"),
        )
        lookups = [look_for_node(self.graph, self, [FwdChainNode(op=op) for op in ops]) for ops in op_sequences]

        found_transposes = [transpose for transpose in lookups if transpose is not None]
        if not found_transposes:
            return False

        return all(transpose.is_width_features_transpose() for transpose in found_transposes)

    def is_transpose_adjacent_to_layer_norm(self):
        """
        Check whether this Transpose sits right before or right after a layer norm.

        Either all of its successors are layer norms (axis ignored), or the node found by the backward
        Add <- Mul <- Div <- Sub lookup is a layer norm.
        """
        if self.op != "Transpose":
            return False

        succs = list(self._graph.successors(self))
        if succs and all(succ.is_layer_norm(ignore_axis=True) for succ in succs):
            return True

        bwd_ops = ("Add", "Mul", "Div", "Sub")
        norm_candidate = look_for_node(self._graph, self, [BwdChainNode(op=op) for op in bwd_ops])
        if not norm_candidate:
            return False
        return bool(norm_candidate.is_layer_norm())

    def is_transpose_connected_to_softmax(self):
        """Check whether this is a [0, 3, 2, 1] Transpose with a Softmax directly before or after it."""
        if self.op != "Transpose":
            return False

        if self.get_transpose_perm() != [0, 3, 2, 1]:
            return False

        softmax_chains = [[BwdChainNode(op="Softmax")], [FwdChainNode(op="Softmax")]]
        softmax = get_node_from_possible_chains(self.graph, self, softmax_chains)
        return softmax is not None

    def is_transpose_after_groups_to_spatial_flatten(self):
        """
        Check whether this Transpose directly follows a groups-to-spatial-flatten Reshape.

        When no Reshape predecessor is found, the falsy lookup result itself is returned.
        """
        if self.op != "Transpose":
            return False

        preceding_reshape = look_for_node(self._graph, self, [BwdChainNode(op="Reshape")])
        if not preceding_reshape:
            return preceding_reshape
        return preceding_reshape.is_groups_to_spatial_flatten()

    def is_transpose_after_spatial_flatten_to_groups(self):
        """
        Check whether this Transpose directly follows a spatial-flatten-to-groups Reshape.

        When no Reshape predecessor is found, the falsy lookup result itself is returned.
        """
        if self.op != "Transpose":
            return False

        preceding_reshape = look_for_node(self._graph, self, [BwdChainNode(op="Reshape")])
        if not preceding_reshape:
            return preceding_reshape
        return preceding_reshape.is_spatial_flatten_to_groups()

    def is_transpose_before_spatial_flatten_and_groups_to_features(self):
        """
        Check whether this Transpose is directly followed by a spatial-flatten-and-groups-to-features Reshape.

        When no Reshape successor is found, the falsy lookup result itself is returned.
        """
        if self.op != "Transpose":
            return False

        following_reshape = look_for_node(self._graph, self, [FwdChainNode(op="Reshape")])
        if not following_reshape:
            return following_reshape
        return following_reshape.is_spatial_flatten_and_groups_to_features()

    def is_rnn_sequence(self):
        """
        Check whether this vertex is a [1, 0, 2] transpose of a [batch, channels, width] tensor that feeds
        one of the RNN_SEQ_OPS.
        """
        feeds_rnn_op = any(look_for_node(self._graph, self, [FwdChainNode(op=rnn_op)]) for rnn_op in RNN_SEQ_OPS)
        if not feeds_rnn_op:
            return False

        pred = next(iter(self._graph.predecessors(self)))
        input_format = pred.output_format
        perm = self.get_transpose_perm()
        return input_format == [Dims.BATCH, Dims.CHANNELS, Dims.WIDTH] and perm == [1, 0, 2]

    def is_null_transpose_attention_block(self):
        """Check whether this vertex matches one of the known attention-block chains around a transpose."""
        bwd, fwd = BwdChainNode, FwdChainNode
        possible_chains = [
            [bwd(op="Split"), fwd(op="Transpose"), fwd(op="MatMul"), fwd(op="Transpose"), fwd(op="Reshape")],
            [bwd(op=op) for op in ("MatMul", "Transpose", "Split", "Reshape")],
            [bwd(op="Softmax"), fwd(op="Transpose"), fwd(op="MatMul"), fwd(op="Reshape")],
            [fwd(op=op) for op in ("MatMul", "Add", "Reshape", "Transpose", "Reshape", "MatMul", "Reshape")],
        ]
        matched_nodes = get_all_nodes_from_possible_chains(self._graph, self, possible_chains)
        return matched_nodes is not None

    def is_heads_width_transpose(self):
        """
        Check whether the transpose permutation and the resulting output format form one of the known
        groups/width layouts.
        """
        perm = self.get_transpose_perm()
        output_format = self.output_format
        known_layouts = (
            ([0, 2, 1, 3], [Dims.BATCH, Dims.GROUPS, Dims.WIDTH, Dims.CHANNELS]),
            ([2, 0, 3, 1, 4], [Dims.STACK, Dims.BATCH, Dims.GROUPS, Dims.WIDTH, Dims.CHANNELS]),
            ([1, 0, 2], [Dims.WIDTH, Dims.GROUPS, Dims.CHANNELS]),
        )
        return any(perm == layout_perm and output_format == layout for layout_perm, layout in known_layouts)

    def is_batch_moving_transpose(self):
        """
        Check whether the transpose permutation is one that relocates the batch axis.

        Some permutations qualify unconditionally; [0, 2, 1, 3] qualifies only when the transpose follows an
        Unsqueeze whose predecessor has a [width, batch, channels] output format.
        """
        perm = self.get_transpose_perm()
        if perm in ([1, 0, 2], [1, 2, 0], [2, 0, 1], [3, 1, 2, 0]):
            return True

        # [3, 1, 2, 0] can't reach this check, since it already returned above
        if perm not in ([0, 2, 1, 3], [3, 1, 2, 0]):
            return False

        unsqueeze = look_for_node(self._graph, self, [BwdChainNode(op="Unsqueeze")])
        if unsqueeze is None:
            return False

        unsqueeze_pred = next(iter(self._graph.predecessors(unsqueeze)))
        return unsqueeze_pred.output_format == [Dims.WIDTH, Dims.BATCH, Dims.CHANNELS]

    def is_last_transpose_attention_block(self):
        """
        Check whether this [1, 0, 2] / [1, 2, 0] Transpose closes a block opened by a matching transpose.

        One of the known chains must lead from this vertex to another Transpose whose permutation is
        [1, 0, 2] or [2, 0, 1].
        """
        if self.op != "Transpose" or self.get_transpose_perm() not in [[1, 0, 2], [1, 2, 0]]:
            return False

        bwd, fwd = BwdChainNode, FwdChainNode
        add_transpose_chains = [
            [bwd(op="Add"), bwd(op="Transpose")],
            [fwd(op="Add"), bwd(op="Add"), fwd(op="Transpose")],
            [bwd(op=op) for op in ("Add", "Add", "Add", "Transpose")],
            [bwd(op=op) for op in ("Add", "Add", "Add", "MatMul", "Reshape", "Transpose")],
            [bwd(op=op) for op in ("Slice", "Add", "Add", "Add", "Concat", "Transpose")],
            [bwd(op=op) for op in ("Reshape", "Gemm", "Reshape", "Transpose")],
            [bwd(op=op) for op in ("Squeeze", "LSTM", "Transpose")],
            [bwd(op=op) for op in ("Squeeze", "RNN", "Transpose")],
        ]

        transpose_node = get_node_from_possible_chains(self._graph, self, add_transpose_chains)
        if transpose_node is None:
            return False
        return transpose_node.get_transpose_perm() in [[1, 0, 2], [2, 0, 1]]

    def is_gcn_block_transpose(self):
        """
        Checks whether this Transpose (perm [0, 1, 3, 2]) is either end of a
        Concat -> Conv -> HardSigmoid -> Mul -> Slice block that is enclosed by two Transpose nodes.
        """
        if self.op != "Transpose" or self.get_transpose_perm() != [0, 1, 3, 2]:
            return False

        ops_after_first = ["Concat", "Conv", "HardSigmoid", "Mul", "Slice", "Transpose"]
        ops_before_last = ["Slice", "Mul", "HardSigmoid", "Conv", "Concat", "Transpose"]

        # this node opens the block: the whole pattern is found among its successors
        nodes_after = get_all_nodes_in_chain(
            self._graph,
            self,
            [FwdChainNode(op=op_name) for op_name in ops_after_first],
        )
        # this node closes the block: the whole pattern is found among its predecessors
        nodes_before = get_all_nodes_in_chain(
            self._graph,
            self,
            [BwdChainNode(op=op_name) for op_name in ops_before_last],
        )
        return nodes_after is not None or nodes_before is not None

    def is_hc_transpose(self):
        """
        Checks whether this Transpose turns a (batch, channels, height, width) input format into
        (batch, height, channels, width).
        """
        if self.op != "Transpose":
            return False

        producer = next(iter(self._graph.predecessors(self)))
        source_format = producer.output_format
        permutation = self.get_transpose_perm()

        if not source_format == [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]:
            return False
        if len(permutation) != 4:
            return False

        # apply the permutation to the input format and compare with the expected layout
        permuted_format = [source_format[axis] for axis in permutation]
        return permuted_format == [Dims.BATCH, Dims.HEIGHT, Dims.CHANNELS, Dims.WIDTH]

    def get_transpose_perm(self):
        """
        Returns the ints of the first "perm" attribute of a Transpose node, or None for any other op.
        """
        if self.op == "Transpose":
            perm_attribute = self.get_attribute_by_name("perm")[0]
            return perm_attribute.ints
        return None

    def is_ew_mean(self):
        """
        Checks whether this is a Mean node that has more than one predecessor in the graph.
        """
        if self.op == "Mean":
            input_nodes = list(self.graph.predecessors(self))
            return len(input_nodes) > 1
        return False

    def is_avgpool_reduce_mean(self):
        """
        Checks whether this ReduceMean reduces exactly one of the two last axes, in which case it is
        handled as an average pooling.

        The axes are read from the "axes" attribute when it exists, otherwise from the initializer /
        constant input at MIN_INPUT_ORDER.
        """
        if self.op != "ReduceMean":
            return False

        axes_attribute = self.get_attribute_by_name("axes")
        if axes_attribute:
            axes = list(axes_attribute[0].ints)
        else:
            axes = self.get_initializer_or_constant_value(MIN_INPUT_ORDER)
        nhwc_shape = self.get_input_shapes(convert_to_nhwc=True)[0]

        def reduces_exactly(*candidates):
            # True when the reduced axes are exactly one of the given single axes
            return any(np.array_equal(axes, np.array([candidate])) for candidate in candidates)

        # reduction over the last axis only, and index 2 of the NHWC-converted shape is larger than 1
        if reduces_exactly(-1, 3) and nhwc_shape[2] > 1:
            return True

        # reduction over the second to last axis only, and index 3 of the NHWC-converted shape is larger than 1
        # NOTE(review): get_avgpool_reduce_mean_info uses index 1 for the height case - confirm index 3 is intended.
        return reduces_exactly(-2, 2) and nhwc_shape[3] > 1

    def get_avgpool_reduce_mean_info(self):
        """
        Returns the kernel shape and the strides of the average pooling that replaces this ReduceMean.

        The reduced axis is taken from the "axes" attribute when it exists, otherwise from the
        initializer / constant input at MIN_INPUT_ORDER. It is mapped to a spatial dimension using the
        predecessor output format; when the predecessor has no output format, axes 2 / -2 are treated
        as height and axes 3 / -1 as width.

        Raises UnsupportedReduceMeanLayerError if the axis is neither height nor width.
        """
        axis = self.get_attribute_by_name("axes")
        axis = axis[0].ints[0] if axis else self.get_initializer_or_constant_value(MIN_INPUT_ORDER)[0]
        input_shape = self.get_input_shapes(convert_to_nhwc=True)[0]
        pred = next(iter(self.graph.predecessors(self)))
        reduce_dim = pred.output_format[axis] if pred.output_format else None

        # is_avgpool_reduce_mean accepts both the positive and the negative index of each spatial axis,
        # so the fallback without an output format has to accept both of them as well
        if (reduce_dim is None and axis in [-2, 2]) or reduce_dim == Dims.HEIGHT:
            kernel_shape = [1, input_shape[1], 1, 1]
            stride = [1, input_shape[1], 1, 1]
        elif (reduce_dim is None and axis in [-1, 3]) or reduce_dim == Dims.WIDTH:
            kernel_shape = [1, 1, input_shape[2], 1]
            stride = [1, 1, input_shape[2], 1]
        else:
            err_msg = (
                f"Reduce mean layer {self.name} has unsupported axis {axis} (must be over one spatial dimension only)."
            )
            raise UnsupportedReduceMeanLayerError(err_msg)
        return kernel_shape, stride

    def get_avgpool_count_include_pad(self):
        """
        Returns the "count_include_pad" attribute as a bool (whether the padding is included in the
        average calculation). A missing attribute is treated as False.
        """
        attribute = self.get_attribute_by_name("count_include_pad")
        if len(attribute) > 0:
            return bool(attribute[0].i)
        return False

    def get_ceil_mode(self, op):
        """
        Returns True only when the "ceil_mode" attribute is on and it actually changes the output shape,
        i.e. the floor based output height or width is not an integer.

        op is compared to LayerType.avgpool; any other value is handled as maxpool (dilations are used).
        """
        ceil_attribute = self.get_attribute_by_name("ceil_mode")
        ceil_requested = len(ceil_attribute) > 0 and bool(ceil_attribute[0].i)
        if not ceil_requested:
            return False

        # the flag is on: it is kept only if ceil and floor modes give different output shapes
        padding_values = self.get_vertex_padding()[1]
        _, pad_top, pad_left, _, pad_bottom, pad_right = padding_values

        in_height, in_width = self.get_input_shapes()[0][1:3]
        k_height, k_width = self.get_kernel_shape()
        s_height, s_width = self.get_strides()[1:3]

        if op == LayerType.avgpool:
            numerator_h = in_height + pad_top + pad_bottom - k_height
            numerator_w = in_width + pad_left + pad_right - k_width
        else:
            # maxpool: the effective kernel size depends on the dilations
            d_height, d_width = self.get_dilations()[1:3]
            numerator_h = in_height + pad_top + pad_bottom - d_height * (k_height - 1) - 1
            numerator_w = in_width + pad_left + pad_right - d_width * (k_width - 1) - 1

        out_height = (numerator_h / s_height) + 1
        out_width = (numerator_w / s_width) + 1
        return not (out_height.is_integer() and out_width.is_integer())

    def get_reduce_mean_info(self):
        """
        Returns (axes, groups, consumed_vertices) of the reduce mean.

        axes are the reduced axes translated to hailo axes (height=1, width=2, channels=3) using the
        predecessor output format, and stay empty when the predecessor has no output format.
        groups is the size of the groups dimension of the input if the format has one, otherwise 1.
        consumed_vertices holds a following Reshape when it restores the rank after a reduce mean
        without keep dims.
        """
        hailo_axis_by_dim = {Dims.CHANNELS: 3, Dims.HEIGHT: 1, Dims.WIDTH: 2}
        axes = []
        groups = 1
        consumed_vertices = []

        pred = next(iter(self._graph.predecessors(self)))
        input_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        if pred.output_format:
            axes = [hailo_axis_by_dim[pred.output_format[onnx_axis]] for onnx_axis in self.get_axes_information()]
            if Dims.GROUPS in pred.output_format:
                groups_index = pred.output_format.index(Dims.GROUPS)
                groups = input_shape[groups_index]

        following_reshape = look_for_node(self._graph, self, [FwdChainNode(op="Reshape")])
        if following_reshape and following_reshape.is_nhw_to_nchw_reshape(self.output_format):
            # reduce mean with keep_dims=False that is followed by a reshape to rank 4
            consumed_vertices = [following_reshape]

        return axes, groups, consumed_vertices

    def is_reduce_mean_layer(self):
        """
        Checks whether this ReduceMean is a reduce mean layer, i.e. reduces the channels (or groups) dim.

        When the predecessor has an output format, the reduced axes are checked against it. Otherwise the
        decision is based on the axes together with the input and output ranks:
        - axes == [-1]: rank 3 input with rank 3 output, or rank 2 input.
        - axes == [1]: rank 4 input with rank 4 output, or a following Reshape that restores the rank.

        Returns False for other ops or when the input / output shapes are unavailable.
        """
        if self.op != "ReduceMean":
            return False

        axes = self.get_axes_information()
        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        output_shapes = self.get_output_shapes(convert_to_nhwc=False)
        if not input_shapes or not output_shapes:
            return False
        # the output shape must come from the output shapes, otherwise the rank checks below are meaningless
        input_shape, output_shape = input_shapes[0], output_shapes[0]

        pred = next(iter(self._graph.predecessors(self)))
        if pred.output_format:
            if any(pred.output_format[i] in [Dims.CHANNELS, Dims.GROUPS] for i in axes):
                return True
        elif axes == [-1]:
            if (len(input_shape) == 3 and len(output_shape) == 3) or len(input_shape) == 2:
                # edge case of transformers: we ignore a transpose before transformer encoder block,
                # therefore axis=[-1] stands for channels.
                return True
        elif axes == [1]:
            if len(input_shape) == 4 and len(output_shape) == 4:
                return True
            reshape_node = look_for_node(self._graph, self, [FwdChainNode(op="Reshape")])
            if reshape_node and reshape_node.is_nhw_to_nchw_reshape(self.output_format):
                # reduce mean with keep_dims=False but has followed reshape to rank 4
                return True
        return False

    def is_inv_pos_activation(self):
        """
        Checks whether this division has a scalar 1.0 as its first input.

        Returns a tuple of the check result and the vertices that supplied the constant (the constant
        node, and a Cast in front of it if there is one). The list is empty when the value comes from
        an initializer or when the check fails.
        """
        if self.op not in DIV_OPS or self.is_ew_div():
            return False, []

        numerator_key = self._info.input[0]

        # first option: the value is a variable initializer of this vertex
        initializer_values = self._graph.values_by_vertex_name.get(self.name, {})
        if numerator_key in initializer_values:
            numerator = initializer_values[numerator_key]
            return numerator.shape == () and float(numerator) == 1.0, []

        # second option: the value is produced by a constant vertex (possibly through a Cast)
        if numerator_key not in self._graph.vertices_by_inp_key:
            return False, []

        producer = self._graph.vertices_by_inp_key[numerator_key]
        consumed_vertices = [producer]
        if producer.op == "Cast":
            producer = self._graph.vertices_by_inp_key.get(producer.input[0])
            consumed_vertices.append(producer)

        if not (producer and producer.is_const()):
            return False, []

        numerator = producer.parse_raw_data()
        if numerator.shape == () and float(numerator) == 1.0:
            return True, consumed_vertices
        return False, []

    def get_nms_config_values(self):
        """
        Collects the NMS config values out of the values stored for this vertex.

        A stored value is used only when its key contains exactly one of the NMS_ORDER keys. The first
        element of the value is taken, -inf is replaced by 0 and numpy scalars are converted to python
        scalars. Returns an empty dict for non NMS ops.
        """
        if self.op not in NMS_OPS:
            return {}

        config_update = {}
        stored_values = self._graph.values_by_vertex_name[self.name]
        for input_key, raw_value in stored_values.items():
            matching_keys = [candidate for candidate in NMS_ORDER if candidate in input_key]
            if len(matching_keys) != 1:
                continue

            value = raw_value[0]
            if value == -np.inf:
                value = 0
            # json schema doesn't accept np types
            if isinstance(value, np.integer):
                value = int(value)
            if isinstance(value, np.floating):
                value = float(value)
            config_update[NMS_ONNX_KEY_TO_CONFIG_KEY[matching_keys[0]].value] = value
        return config_update

    def get_nms_last_conv_info(self, spatial_input_shape):
        """
        Looks for one of the known postprocess chains after this conv and, if one is found, returns the
        anchors and the stride extracted from the Mul vertex at its end. Returns an empty dict otherwise.

        Chains ending with Mul -> Pow -> Mul use an anchors divide factor of 1, chains ending with
        Pow -> Mul use a factor of 4.
        """
        if self.op not in CONV2D_OPS:
            return {}

        chain_start = [FwdChainNode(op="Reshape"), FwdChainNode(op="Transpose")]
        chain_middles = [
            [FwdChainNode(op="Sigmoid"), FwdChainNode(op="Slice")],
            [FwdChainNode(op="Split"), FwdChainNode(op="Sigmoid")],
            [FwdChainNode(op="Sigmoid"), FwdChainNode(op="Split")],
            [FwdChainNode(op="Sigmoid"), FwdChainNode(op="ScatterND"), FwdChainNode(op="Slice")],
        ]
        # the possible chain ends with their anchors divide factor, by lookup order
        chain_ends = [
            ([FwdChainNode(op="Mul"), FwdChainNode(op="Pow"), FwdChainNode(op="Mul")], 1),
            ([FwdChainNode(op="Pow"), FwdChainNode(op="Mul")], 4),
        ]

        for chain_end, anchors_divide_factor in chain_ends:
            candidate_chains = [chain_start + middle + chain_end for middle in chain_middles]
            mul_vertex = get_node_from_possible_chains(self._graph, self, candidate_chains)
            if mul_vertex:
                return self._extract_anchors_and_strides(mul_vertex, spatial_input_shape, anchors_divide_factor)
        return {}

    def _extract_anchors_and_strides(self, mul_vertex, spatial_input_shape, anchors_divide_factor):
        """
        Builds the anchors / stride entry of this vertex out of the constant multiplied by mul_vertex.

        The constant is read from a preceding Constant node, or from the first value stored for
        mul_vertex; an empty dict is returned when neither exists. The unique values are kept in order
        of first appearance, divided by anchors_divide_factor and split alternately to "w" and "h".
        """
        constant_node = look_for_node(self._graph, mul_vertex, [BwdChainNode(op="Constant")])
        if constant_node:
            raw_data = constant_node.parse_raw_data()
        else:
            stored_values = self._graph.values_by_vertex_name.get(mul_vertex.name, None)
            if not stored_values:
                return {}
            raw_data = next(iter(stored_values.values()))

        unique_values, first_indices = np.unique(raw_data, return_index=True)
        # np.unique sorts by value, restore the order in which the values first appear in the data
        by_appearance = sorted(zip(first_indices, unique_values), key=lambda pair: pair[0])
        anchors = [float(value) / anchors_divide_factor for _, value in by_appearance]

        stride = spatial_input_shape // self.get_output_shapes()[0][1]
        return {self.name: {"w": anchors[::2], "h": anchors[1::2], "stride": stride}}

    def get_yolov6_yolox_reshape_nodes(self):
        """
        Returns the Reshape predecessors of the single Concat that feeds this Transpose.
        Returns None when the structure doesn't match (including a Concat with a non Reshape input).
        """
        if self.op != "Transpose":
            return None

        inputs = list(self._graph.predecessors(self))
        if len(inputs) != 1 or inputs[0].op != "Concat":
            return None

        concat_inputs = list(self._graph.predecessors(inputs[0]))
        if len(concat_inputs) == 0:
            return None
        if any(node.op != "Reshape" for node in concat_inputs):
            return None
        return concat_inputs

    def is_end_of_yolox_structure(self):
        """
        Checks whether every Reshape found by get_yolov6_yolox_reshape_nodes is fed by a Concat that has
        both a Conv and a Sigmoid predecessor.

        Returns a tuple of the check result and the names of all the predecessors of those Concat nodes
        (an empty list when the check fails).
        """
        reshape_nodes = self.get_yolov6_yolox_reshape_nodes()
        if not reshape_nodes:
            return False, []

        end_nodes = []
        for reshape in reshape_nodes:
            conv_chain = get_all_nodes_in_chain(
                self._graph,
                reshape,
                [BwdChainNode(op="Concat"), BwdChainNode(op="Conv")],
            )
            sigmoid_chain = get_all_nodes_in_chain(
                self._graph,
                reshape,
                [BwdChainNode(op="Concat"), BwdChainNode(op="Sigmoid")],
            )
            if conv_chain is None or sigmoid_chain is None:
                return False, []

            # the first node of the chain is the Concat, all of its inputs are end nodes
            end_nodes.extend([pred.name for pred in self._graph.predecessors(conv_chain[0])])

        return True, end_nodes

    def is_end_of_yolov6_structure(self, second_vertex):
        """
        Checks whether this vertex and second_vertex form a pair in which one of them has a
        Concat <- Reshape <- Sigmoid chain behind it and the other has a Concat <- Reshape chain.

        Returns a tuple of the check result and the names of the first predecessor of every Reshape node
        of both vertices (an empty list when the check fails).
        """
        own_reshapes = self.get_yolov6_yolox_reshape_nodes()
        other_reshapes = second_vertex.get_yolov6_yolox_reshape_nodes()
        if not (own_reshapes and other_reshapes):
            return False, []

        end_nodes = [
            next(iter(self._graph.predecessors(reshape))).name for reshape in own_reshapes + other_reshapes
        ]
        plain_chain = [BwdChainNode("Concat"), BwdChainNode("Reshape")]
        with_sigmoid_chain = plain_chain + [BwdChainNode("Sigmoid")]

        # either of the two vertices may be the one with the sigmoid
        for sigmoid_side, plain_side in ((self, second_vertex), (second_vertex, self)):
            if not get_all_nodes_in_chain(self._graph, sigmoid_side, with_sigmoid_chain):
                continue
            if get_all_nodes_in_chain(self._graph, plain_side, plain_chain):
                return True, end_nodes
        return False, []

    def get_possible_yolov8_postprocess(self):
        """
        Returns the node found at the end of one of the known forward chains
        Split -> Reshape -> Transpose -> Softmax -> [Transpose ->] Conv, optionally preceded by
        Concat -> Reshape -> Concat. Returns None for ops that are neither conv nor concat.
        """
        if self.op not in CONV2D_OPS + CONCAT_OPS:
            return None

        concat_prefix = ["Concat", "Reshape", "Concat"]
        shared_middle = ["Split", "Reshape", "Transpose", "Softmax"]
        possible_endings = [["Conv"], ["Transpose", "Conv"]]

        # the chains with the concat prefix come first, each with both of the possible endings
        chains = [
            [FwdChainNode(op_name) for op_name in prefix + shared_middle + ending]
            for prefix in (concat_prefix, [])
            for ending in possible_endings
        ]
        return get_node_from_possible_chains(self._graph, self, chains)

    def get_yolov8_reg_length(self):
        """
        Returns the number of elements in the kernel of the conv found by get_possible_yolov8_postprocess,
        or None when there is no such conv or no values are stored for it.
        """
        conv = self.get_possible_yolov8_postprocess()
        if not conv:
            return None

        conv_values = self._graph.values_by_vertex_name.get(conv.name)
        if not conv_values:
            return None
        return len(conv_values["kernel"].flatten())

    def get_conv3d_padding_type(self, pads, kernel_d, kernel_h, kernel_w, stride_d, stride_h, stride_w):
        """
        Classifies the padding of a 3D convolution as PaddingType.same or PaddingType.same_tensorflow.

        Raises UnsupportedConv3DError when the output shape isn't ceil(input / stride) in all of the
        three last dimensions. Returns None when the three dimensions don't agree on the padding type.
        Dilations are assumed to be 1.
        """
        in_d, in_h, in_w = self.get_input_shapes(convert_to_nhwc=False)[0][-3:]
        out_d, out_h, out_w = self.get_output_shapes(convert_to_nhwc=False)[0][-3:]

        # padding same means that every output dim equals ceil(input dim / stride)
        h_is_same = int(np.ceil(in_h / stride_h)) == out_h
        is_same_shape = h_is_same and (int(np.ceil(in_w / stride_w)) == out_w) and (int(np.ceil(in_d / stride_d)) == out_d)
        if not is_same_shape:
            raise UnsupportedConv3DError(f"{self.name} is 3D convolution with unsupported padding: {pads}")

        def split_total_padding(input_size, kernel_size, stride):
            # total padding of a single dimension, split with the smaller half at the beginning
            remainder = input_size % stride
            effective_stride = stride if remainder == 0 else remainder
            total = max(kernel_size - effective_stride, 0)
            begin = total // 2
            return begin, total - begin

        def swap_if_pads_exceed(begin, end, given_pads):
            # when the given pads sum is greater than the calculated total, onnx prefers to pad the
            # beginning first, so the calculated begin / end are swapped
            if sum(given_pads) > begin + end:
                return end, begin
            return begin, end

        front, back = split_total_padding(in_d, kernel_d, stride_d)
        top, bottom = split_total_padding(in_h, kernel_h, stride_h)
        left, right = split_total_padding(in_w, kernel_w, stride_w)

        top, bottom = swap_if_pads_exceed(top, bottom, pads[:2])
        left, right = swap_if_pads_exceed(left, right, pads[2:4])
        front, back = swap_if_pads_exceed(front, back, pads[4:])

        begin_end_pairs = [(top, bottom), (left, right), (front, back)]
        # in tensorflow the beginning gets floor(total / 2) and the end gets ceil(total / 2), so a smaller
        # beginning in all of the dimensions means padding same tensorflow
        if all(begin < end for begin, end in begin_end_pairs):
            return PaddingType.same_tensorflow

        if all(begin >= end for begin, end in begin_end_pairs):
            return PaddingType.same
        return None

    def get_conv3d_info(self):
        """
        Collects the parameters of a 3D convolution into a Conv3DInfo.

        Returns a tuple of the Conv3DInfo and the vertices consumed while reading the kernel, the bias
        and the padding. Raises UnsupportedConv3DError when the input or output shapes are missing, or
        when a dilation differs from 1.
        """
        input_shapes = self.get_input_shapes(convert_to_nhwc=False)
        if not input_shapes:
            raise UnsupportedConv3DError(f"Can't find input shapes for 3D convolution {self.name}")

        first_input_shape = input_shapes[0]
        input_channels = first_input_shape[1]
        input_disparity = first_input_shape[2]
        groups = self.get_groups()

        output_shapes = self.get_output_shapes()
        if not output_shapes:
            raise UnsupportedConv3DError(f"Can't find output shapes for 3D convolution {self.name}")

        consumed_vertices = []

        onnx_kernel, kernel_nodes = self.get_kernel(is_conv2d=False)
        consumed_vertices.extend(kernel_nodes)
        if len(onnx_kernel.shape) == 4:
            # a rank 4 kernel gets an additional axis before the last one
            onnx_kernel = np.expand_dims(onnx_kernel, axis=-2)
        kernel_d = onnx_kernel.shape[-3]
        # [f_out, f_in, k_d, k_h, k_w] -> [k_w, k_h, f_in * k_d, f_out]
        depth_slices = np.dsplit(onnx_kernel, kernel_d)
        merged_depth = np.squeeze(np.concatenate(depth_slices, axis=1), axis=2)
        kernel = np.transpose(merged_depth, [2, 3, 1, 0])

        bias, bias_nodes = self.get_bias()
        consumed_vertices.extend(bias_nodes)

        _, pads_val, _, padding_nodes = self.get_vertex_padding()
        if all(pad == 0 for pad in pads_val):
            padding = PaddingType.valid
        else:
            padding = TemporaryPaddingType.conv3d
        consumed_vertices.extend(padding_nodes)

        onnx_strides = self.get_attribute_by_name("strides")[0].ints

        dilations_attr = self.get_attribute_by_name("dilations")
        if dilations_attr and any(dilation != 1 for dilation in dilations_attr[0].ints):
            raise UnsupportedConv3DError(
                f"{self.name} is 3D convolution with non-trivial dilations which is unsupported",
            )
        dilations = [1, 1, 1, 1]

        if len(onnx_strides) == 2:
            # only two strides are given, the height stride is set to 1
            stride_d, stride_w = onnx_strides
            stride_h = 1
        else:
            stride_d, stride_h, stride_w = onnx_strides
        strides = [1, stride_h, stride_w, stride_d * input_channels]

        info = Conv3DInfo(kernel, bias, padding, pads_val, strides, dilations, output_shapes, input_disparity, groups)
        return info, consumed_vertices

    def is_concat_of_pos_embeds(self):
        """
        Checks whether this is a Concat over axis 1 of exactly two Add nodes, each of them having a
        single predecessor which is a Transpose with perm [0, 2, 1].
        """
        if self.op != "Concat" or self.get_axis() != 1:
            return False

        concat_inputs = list(self.graph.predecessors(self))
        if len(concat_inputs) != 2 or not all(node.op == "Add" for node in concat_inputs):
            return False

        for add_node in concat_inputs:
            add_inputs = list(self.graph.predecessors(add_node))
            if len(add_inputs) != 1:
                return False
            single_input = add_inputs[0]
            if single_input.op != "Transpose" or single_input.get_transpose_perm() != [0, 2, 1]:
                return False
        return True

    def is_concat_of_transpose_width_features(self):
        """
        Checks whether this is a Concat over axis 3 of exactly two predecessors, both of them reporting
        is_width_features_transpose().
        """
        if self.op != "Concat" or self.get_axis() != 3:
            return False

        concat_inputs = list(self.graph.predecessors(self))
        if len(concat_inputs) != 2:
            return False
        return all(node.is_width_features_transpose() for node in concat_inputs)

    def is_spatial_flatten_complex(self):
        """
        Checks whether this Unsqueeze (axes [-1], rank 3 input) starts an Unsqueeze -> Transpose -> Reshape
        structure in which the transpose perm is [0, 3, 2, 1] and the rank 4 reshape output keeps the
        last input dim at index 1 while splitting input dim 1 over the two last dims.

        Returns a tuple of the check result and the consumed vertices (an empty list when the check fails).
        """
        if self.op != "Unsqueeze":
            return False, []

        in_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        axes = self.get_axes_information()
        chain = get_all_nodes_in_chain(self.graph, self, [FwdChainNode("Transpose"), FwdChainNode("Reshape")])
        if not (len(in_shape) == 3 and chain and axes == [-1]):
            return False, []

        transpose_node, reshape_node = chain[0], chain[1]
        consumed_vertices = []
        consumed_vertices.extend(chain)
        perm = transpose_node.get_transpose_perm()
        reshaped, reshape_consumed = reshape_node.get_reshape_shapes()
        consumed_vertices.extend(reshape_consumed)

        if perm != [0, 3, 2, 1] or len(reshaped) != 4:
            return False, []
        if reshaped[1] == in_shape[2] and reshaped[2] * reshaped[3] == in_shape[1]:
            return True, consumed_vertices
        return False, []

    def get_log_softmax_axis(self):
        """
        Translates the onnx axis of the node to a hailo axis (height=1, width=2, channels=3) using the
        output format of the node.

        Raises UnsupportedLogSoftmaxLayerError when the node has no output format, or when the axis
        isn't one of height, width or channels.
        """
        if not self.output_format:
            raise UnsupportedLogSoftmaxLayerError(f"Unable to find logsoftmax axis at {self.name}.")

        onnx_axis = self.get_axis()
        hailo_axis_by_dim = {Dims.CHANNELS: 3, Dims.WIDTH: 2, Dims.HEIGHT: 1}
        hailo_axis = hailo_axis_by_dim.get(self.output_format[onnx_axis])
        if hailo_axis is None:
            raise UnsupportedLogSoftmaxLayerError(f"{self.op} is not supported for {onnx_axis} axis at {self.name}.")
        return hailo_axis

    def is_instance_normalization_reshape(self):
        """
        Checks whether this Reshape opens a Reshape -> InstanceNormalization -> Reshape structure in which
        the output shape of the second reshape equals the input shape of this one.
        """
        if self.op != "Reshape":
            return False

        chain_to_second_reshape = [FwdChainNode(op="InstanceNormalization"), FwdChainNode(op="Reshape")]
        closing_reshape = look_for_node(self._graph, self, chain_to_second_reshape)
        if closing_reshape is None:
            return False

        shape_before = self.get_input_shapes(convert_to_nhwc=False)[0]
        shape_after = closing_reshape.get_output_shapes(convert_to_nhwc=False)[0]
        return shape_before == shape_after

    def is_group_norm_reshape(self):
        """
        Checks whether this Reshape opens a Reshape -> InstanceNormalization -> Reshape structure that
        implements a group normalization.

        This reshape has to output rank 3, and dim 1 of its input must equal dim 1 of the second reshape
        output. The result is True when dim 1 of the input divided by dim 1 of the output is a whole
        number that is at least 1.

        Returns a tuple of the check result and a format. The format is (batch, channels, width) when the
        input format is (batch, channels, height, width) and the second reshape flattens the spatial
        dims, otherwise None.
        """
        not_group_norm = (False, None)
        if self.op != "Reshape":
            return not_group_norm

        chain_to_second_reshape = [FwdChainNode(op="InstanceNormalization"), FwdChainNode(op="Reshape")]
        closing_reshape = look_for_node(self._graph, self, chain_to_second_reshape)
        if closing_reshape is None:
            return not_group_norm

        in_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        out_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
        closing_out_shape = closing_reshape.get_output_shapes(convert_to_nhwc=False)[0]

        if len(out_shape) != 3 or in_shape[1] != closing_out_shape[1]:
            return not_group_norm

        ratio = in_shape[1] / out_shape[1]
        is_valid_ratio = ratio.is_integer() and ratio >= 1

        nchw_format = [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]
        if self.input_format == nchw_format and np.prod(in_shape[2:]) == closing_out_shape[-1]:
            return is_valid_ratio, [Dims.BATCH, Dims.CHANNELS, Dims.WIDTH]
        return is_valid_ratio, None

    def get_group_norm_info(self):
        """
        Collects the group normalization info of a Reshape -> InstanceNormalization -> Reshape structure
        that starts at this node.

        The info of the instance normalization is extended with "groups" (dim 1 of the instance
        normalization output shape) and with "axes" [1, 2, 3].

        Returns a tuple of the info, the rms norm flag (always False) and the consumed vertices. The
        second reshape is consumed unless it is a spatial flatten reshape after group norm.
        """
        chain_nodes = get_all_nodes_in_chain(
            self._graph,
            self,
            [FwdChainNode(op="InstanceNormalization"), FwdChainNode(op="Reshape")],
        )
        inst_norm, closing_reshape = chain_nodes

        consumed_vertices = [inst_norm]
        is_spatial_flatten = closing_reshape.is_spatial_flatten_reshape_after_group_norm()[0]
        if not is_spatial_flatten:
            consumed_vertices.append(closing_reshape)

        group_norm_info, inst_norm_consumed = inst_norm.get_instance_normalization_info()
        group_norm_info["groups"] = inst_norm.get_output_shapes(convert_to_nhwc=False)[0][1]
        group_norm_info["axes"] = [1, 2, 3]

        # Currently we don't identify group RMS normalization
        rms_norm = False
        return group_norm_info, rms_norm, consumed_vertices + inst_norm_consumed

    def is_layer_norm(self, ignore_axis=False):
        """
        Checks whether this node is part of a layer normalization (or RMS normalization) structure.

        The structure that is looked for depends on the current node:
        - a square node: forward ReduceMean -> Add -> Sqrt -> Div / Reciprocal -> Mul (RMS norm).
        - an op in MUL_OPS: backward Div / Reciprocal -> Sqrt -> Add -> ReduceMean -> Pow (RMS norm).
        - a ReduceMean: forward Sub -> Pow / Mul -> ReduceMean -> Add -> Sqrt -> Div.
        - an op in SUB_OPS that follows a ReduceMean: forward Pow -> ReduceMean -> Add -> Sqrt -> Div.

        When ignore_axis is True the ReduceMean nodes aren't required to pass is_reduce_mean_layer().
        Returns a falsy value when the structure isn't matched.
        """
        is_rms_norm = False
        first_reduce_mean = None
        if self.is_square():
            # RMS norm that is detected from its first node, the square
            rms_fwd_chains = [
                [
                    FwdChainNode("ReduceMean"),
                    FwdChainNode("Add"),
                    FwdChainNode("Sqrt"),
                    FwdChainNode("Div"),
                    FwdChainNode("Mul"),
                ],
                [
                    FwdChainNode("ReduceMean"),
                    FwdChainNode("Add"),
                    FwdChainNode("Sqrt"),
                    FwdChainNode("Reciprocal"),
                    FwdChainNode("Mul"),
                ],
            ]
            nodes_dict = self._get_nodes_dict_in_possible_chains(rms_fwd_chains)
            nodes_dict["Pow"] = self
            is_rms_norm = True
        elif self.op in MUL_OPS:
            # RMS norm that is detected from its last node, the multiplication
            rms_bwd_chains = [
                [
                    BwdChainNode("Div"),
                    BwdChainNode("Sqrt"),
                    BwdChainNode("Add"),
                    BwdChainNode("ReduceMean"),
                    BwdChainNode("Pow"),
                ],
                [
                    BwdChainNode("Reciprocal"),
                    BwdChainNode("Sqrt"),
                    BwdChainNode("Add"),
                    BwdChainNode("ReduceMean"),
                    BwdChainNode("Pow"),
                ],
            ]
            nodes_dict = self._get_nodes_dict_in_possible_chains(rms_bwd_chains)
            nodes_dict["Mul"] = self
            is_rms_norm = True
        elif self.op in ["ReduceMean"] and (ignore_axis or self.is_reduce_mean_layer()):
            first_reduce_mean = self
            # the reduce mean is over the channels axis, might be layer norm operation
            layer_norm_chains = [
                [
                    FwdChainNode("Sub"),
                    FwdChainNode("Pow"),
                    FwdChainNode("ReduceMean"),
                    FwdChainNode("Add"),
                    FwdChainNode("Sqrt"),
                    FwdChainNode("Div"),
                ],
                [
                    FwdChainNode("Sub"),
                    FwdChainNode("Mul"),
                    FwdChainNode("ReduceMean"),
                    FwdChainNode("Add"),
                    FwdChainNode("Sqrt"),
                    FwdChainNode("Div"),
                ],
            ]
            nodes_dict = self._get_nodes_dict_in_possible_chains(layer_norm_chains)
        elif self.op in SUB_OPS:
            # layer norm that is detected from the sub that follows the first reduce mean
            first_reduce_mean = look_for_node(self._graph, self, [BwdChainNode("ReduceMean")])
            if not first_reduce_mean or not (ignore_axis or first_reduce_mean.is_reduce_mean_layer()):
                return False
            layer_norm_chain = [
                FwdChainNode("Pow"),
                FwdChainNode("ReduceMean"),
                FwdChainNode("Add"),
                FwdChainNode("Sqrt"),
                FwdChainNode("Div"),
            ]
            nodes_dict = self._get_nodes_dict_in_chain(layer_norm_chain)
            nodes_dict["Sub"] = self
        else:
            return False

        # the nodes are keyed by their op, a full match of any of the structures above has six entries
        if len(nodes_dict) != 6:
            return False

        power = nodes_dict.get("Pow") or nodes_dict.get("Mul")
        sub, reduce_mean = nodes_dict.get("Sub"), nodes_dict.get("ReduceMean")
        add, div, reciprocal = nodes_dict.get("Add"), nodes_dict.get("Div"), nodes_dict.get("Reciprocal")
        ew_mul = nodes_dict.get("Mul")

        # power, reduce mean and add are mandatory, together with at least one of div / reciprocal
        if any(node is None for node in [power, reduce_mean, add]) or all(node is None for node in [div, reciprocal]):
            return False

        # RMS norm: a div without a reciprocal has to be an inverse (1 / x) activation, and the final mul
        # has to share a predecessor with the power node
        if is_rms_norm and (
            (div and not div.is_inv_pos_activation()[0] and not reciprocal)
            or (ew_mul and not self._has_common_stem(ew_mul, power))
        ):
            return False

        # layer norm: the first reduce mean shares a predecessor with the sub, and the div / reciprocal
        # shares a predecessor with the power node
        if not (
            is_rms_norm
            or (
                self._has_common_stem(first_reduce_mean, sub)
                and (
                    (div and self._has_common_stem(div, power))
                    or (reciprocal and self._has_common_stem(reciprocal, power))
                )
            )
        ):
            return False

        # final validation: the power is a square, the add is a normalization with a scalar value and
        # the reduce mean is a reduce mean layer (unless the axis is ignored)
        return (
            power.is_square()
            and (add.is_normalization() and add.get_normalization_input_raw_values()[0].shape == ())
            and (ignore_axis or reduce_mean.is_reduce_mean_layer())
        )

    def is_l2_normalization(self):
        """
        Checks whether this ReduceL2 is the reduction of an L2 normalization: it reduces the channels
        dim (by the predecessor output format) and is followed by a Div, with an Unsqueeze in between
        when keepdims is off.

        Returns False when the predecessor has no output format, since the reduced dim can't be verified.
        """
        if self.op != "ReduceL2":
            return False
        # checks the axis is channels using output format
        pred = next(iter(self.graph.predecessors(self)))
        axes = self.get_attribute_by_name("axes")
        axes = axes[0].ints if axes else self.get_initializer_or_constant_value(MIN_INPUT_ORDER)
        if not pred.output_format or pred.output_format[axes[0]] != Dims.CHANNELS:
            return False

        # keepdims is an optional attribute, the onnx default value is 1 (keep the reduced dim)
        keepdims_attr = self.get_attribute_by_name("keepdims")
        keepdims = keepdims_attr[0].i if len(keepdims_attr) > 0 else 1

        # if keepdims is off, then the current node should be followed by an unsqueeze node
        chain = (
            [FwdChainNode(op="Unsqueeze"), FwdChainNode(op="Div")] if keepdims == 0 else [FwdChainNode(op="Div")]
        )

        div_node = look_for_node(self._graph, self, chain)

        return div_node is not None

    def validate_scatter_nd_info(self):
        """
        Validates that the node has exactly three inputs (data, indices and updates) and that none of
        them is empty. Raises UnsupportedScatterNDError otherwise.
        """
        if len(self._info.input) != 3:
            raise UnsupportedScatterNDError(
                f"ScatterND info missing at {self.name}, expected 3 inputs but {len(self._info.input)} were found",
            )

        named_inputs = zip(("data", "indices", "updates"), self._info.input)
        missing_names = [input_name for input_name, value in named_inputs if not value]
        if missing_names:
            missing_info = ", ".join(missing_names)
            raise UnsupportedScatterNDError(
                f"ScatterND info missing at {self.name}: {missing_info}",
            )

    def is_ew_op_with_multi_non_const_inputs(self):
        """Return True when this node's op is in EW_OPS and more than one predecessor is non-const."""
        non_const_count = sum(1 for pred in self.graph.predecessors(self) if not pred.is_const())
        return self.op in EW_OPS and non_const_count > 1

    def is_optional_graph_start_node(self):
        """Return True when this node may serve as a graph start node.

        A node is rejected when it is a const, a shape op, an element-wise op with several
        non-const inputs, or when its op is one of MUL_OPS, OUTPUT_OPS or LOGICAL_OPS.
        """
        if self.is_const() or self.is_shape_op():
            return False
        if self.is_ew_op_with_multi_non_const_inputs():
            return False
        excluded_ops = MUL_OPS + OUTPUT_OPS + LOGICAL_OPS
        return self.op not in excluded_ops

    def _get_nodes_dict_in_chain(self, chain):
        """Map op name -> node for the nodes matched by `chain` starting at this node.

        Returns an empty dict when the chain lookup yields None. If several matched nodes share
        an op, the last one wins.
        """
        chain_nodes = get_all_nodes_in_chain(self.graph, self, chain)
        return {} if chain_nodes is None else {found.op: found for found in chain_nodes}

    def _get_nodes_dict_in_possible_chains(self, chains):
        """Map op name -> node for the nodes returned by the possible-chains lookup from this node.

        Returns an empty dict when the lookup yields None. If several returned nodes share an
        op, the last one wins.
        """
        matched = get_all_nodes_from_possible_chains(self.graph, self, chains)
        if matched is None:
            return {}
        return {found.op: found for found in matched}

    def _has_common_stem(self, a, b):
        """Return True when nodes `a` and `b` share at least one direct predecessor."""
        parents_of_a = list(self._graph.predecessors(a))
        parents_of_b = list(self._graph.predecessors(b))
        for candidate in parents_of_b:
            if candidate in parents_of_a:
                return True
        return False

    def is_transpose_before_conv3d(self):
        """Return True when this Transpose feeds only conv3d nodes with a B/C/G/H/W layout.

        The output layout is derived by applying the transpose permutation to the node's input
        format and must equal [BATCH, CHANNELS, GROUPS, HEIGHT, WIDTH].
        """
        if self.op != "Transpose" or self.input_format is None:
            return False

        feeds_only_conv3d = all(succ.is_conv3d() for succ in self._graph.successors(self))
        permuted_format = [self.input_format[axis] for axis in self.get_transpose_perm()]
        expected_format = [Dims.BATCH, Dims.CHANNELS, Dims.GROUPS, Dims.HEIGHT, Dims.WIDTH]
        return feeds_only_conv3d and permuted_format == expected_format

    def is_decomposed_l2_norm(self):
        """Return True when this Abs node starts a decomposed L2 norm.

        The expected forward chain is Abs -> Pow -> ReduceSum -> Pow -> Clip -> Expand -> Div,
        where the final Div also takes the Abs node's own input as one of its predecessors.
        """
        if self.op != "Abs":
            return False

        ops_after_abs = ("Pow", "ReduceSum", "Pow", "Clip", "Expand", "Div")
        div = look_for_node(self._graph, self, [FwdChainNode(op) for op in ops_after_abs])
        if not div:
            return False

        # division of input with the norm
        abs_input = next(iter(self._graph.predecessors(self)))
        return abs_input in self._graph.predecessors(div)

    def is_features_to_stack_with_flat_height_reshape(self, input_format):
        """Check for a rank-3 to rank-5 Reshape that splits the channels axis.

        Using the WIDTH / CHANNELS positions of `input_format`, the reshape matches when the width
        entry is the same in the input and output shapes, the input channels equal
        output_shape[-1] * output_shape[2] * output_shape[-2], and output_shape[-2] is 1.
        """
        if self.op != "Reshape" or self.input_format is None:
            return False

        in_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        out_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
        if len(in_shape) != 3 or len(out_shape) != 5:
            return False

        w_axis = input_format.index(Dims.WIDTH)
        c_axis = input_format.index(Dims.CHANNELS)
        in_width, out_width, in_features = in_shape[w_axis], out_shape[w_axis], in_shape[c_axis]
        # Output rank is 5 here, so these fixed positions always exist.
        groups, out_height, out_features = out_shape[2], out_shape[-2], out_shape[-1]
        return in_width == out_width and in_features == out_features * groups * out_height and out_height == 1

    def is_features_to_stack_with_flat_groups_reshape(self, input_format):
        """Check for a rank-3 to rank-4 Reshape that splits the channels axis.

        The node's own input format must be [WIDTH, BATCH, CHANNELS]. Using the WIDTH / CHANNELS
        positions of `input_format`, the reshape matches when the width entry is the same in the
        input and output shapes, the input channels equal
        output_shape[-1] * output_shape[2] * output_shape[1], and output_shape[1] is 1.
        """
        if self.op != "Reshape":
            return False
        if self.input_format is None or self.input_format != [Dims.WIDTH, Dims.BATCH, Dims.CHANNELS]:
            return False

        in_shape = self.get_input_shapes(convert_to_nhwc=False)[0]
        out_shape = self.get_output_shapes(convert_to_nhwc=False)[0]
        if len(in_shape) != 3 or len(out_shape) != 4:
            return False

        w_axis = input_format.index(Dims.WIDTH)
        c_axis = input_format.index(Dims.CHANNELS)
        in_width, out_width, in_features = in_shape[w_axis], out_shape[w_axis], in_shape[c_axis]
        # Output rank is 4 here, so these fixed positions always exist.
        stack, groups, out_features = out_shape[1], out_shape[2], out_shape[-1]

        return in_width == out_width and in_features == out_features * groups * stack and stack == 1

    def is_reshape_for_gather_features_slice(self):
        """Return True when this Reshape follows a Gather fed by a gather-features-slice Transpose.

        Walks backwards Reshape <- Gather <- Transpose and delegates the final decision to the
        Transpose node's is_transpose_for_gather_features_slice().
        """
        if self.op != "Reshape":
            return False

        bwd_chain = [BwdChainNode("Gather"), BwdChainNode("Transpose")]
        chain_nodes = get_all_nodes_in_chain(self.graph, self, bwd_chain)
        # The last node of the matched chain is the Transpose at its far end.
        return chain_nodes[-1].is_transpose_for_gather_features_slice() if chain_nodes else False

    def is_transpose_for_gather_features_slice(self):
        """Return True when this Transpose fans out into three Gather branches that slice the stack.

        All of the following must hold:
          * the node has exactly three successors and all of them are Gather nodes;
          * each Gather selects a single index, and the three indices are distinct;
          * each Gather operates on the axis its input format marks as Dims.STACK;
          * each Gather is followed by a Reshape -> Transpose chain.
        """
        if self.op != "Transpose":
            return False

        succs = list(self._graph.successors(self))
        if not all(succ.op == "Gather" for succ in succs) or len(succs) != 3:  # 3 branches (KQV)
            return False

        indices_per_gather = [succ.get_gather_index()[0] for succ in succs]
        if any(len(indices) != 1 for indices in indices_per_gather):  # validate all indices are single value
            return False

        selected_indices = [indices[0] for indices in indices_per_gather]
        if len(selected_indices) != len(set(selected_indices)):  # validate the indices differ from one another
            return False

        # Evaluated for every branch up front (not lazily) so all successors are inspected.
        gather_dims = [succ.input_format[succ.get_axis()] for succ in succs]

        if any(dim != Dims.STACK for dim in gather_dims):  # validate all axes are stack
            return False

        split_branch = [FwdChainNode("Reshape"), FwdChainNode("Transpose")]
        return all(get_all_nodes_in_chain(self.graph, succ, split_branch) for succ in succs)


class ONNXGraph(NNGraph):
    """Graph of ONNXGraphNode vertices built from the nodes and inputs of an ONNX model.

    The constructor creates one vertex per network input and per raw node, collects the
    initializer values and output tensor shapes per vertex, and wires the edges between them.
    NOTE(review): self._raw_proto and self._values are presumably set by NNGraph.__init__ from
    `graph` and `values` - confirm against the base class.
    """

    def __init__(self, graph, values, net_input, tensor_shapes, output_shapes, opset_version, net_input_format):
        super().__init__(graph, values)
        self._vertices_by_name = {}
        self._tensor_shapes_by_name = {}
        self._init_graph_input(net_input)
        self._init_vertices()
        # Values are matched against the raw input names of each vertex, so they must be collected
        # before _init_vertices_connections rewrites vertex.input.
        self._init_values_by_vertices()
        self._init_vertices_connections()
        self._init_tensor_shapes_by_vertices(tensor_shapes)
        self._output_shapes = output_shapes
        self._opset_version = opset_version
        self._net_input_format = net_input_format if net_input_format else {}
        self._null_transposes = []

    def _init_graph_input(self, net_input):
        """Create an input vertex for every non-initializer graph input used by a non-const node.

        Raises:
            UnsupportedModelError: if no input vertex was created, or if some non-initializer
                input is not consumed by any non-const node.
        """
        # Sets keep the membership tests below O(1) instead of rescanning all initializers and
        # all nodes for every candidate input.
        initializer_names = {x.name for x in self._values}
        consumed_names = {
            name for node in self._raw_proto if node.op_type not in CONST_OPS for name in node.input
        }
        non_initializer_inputs = [x for x in net_input if x.name not in initializer_names]
        self._net_input = []
        for inp in non_initializer_inputs:
            if inp.name in consumed_names:
                input_vertex = ONNXGraphNode(inp, self, is_input_vertex=True)
                self._net_input.append(input_vertex)
                self.add_node(input_vertex)
        if not self._net_input or len(non_initializer_inputs) != len(self._net_input):
            raise UnsupportedModelError(
                f"Couldn't find inputs from ONNX proto. Number of expected inputs: {len(non_initializer_inputs)}, "
                f"Inputs found: {len(self._net_input)}",
            )

    def _init_vertices(self):
        """Create a vertex for every raw node and register it both in the graph and by name."""
        for node in self._raw_proto:
            vertex = ONNXGraphNode(node, self)
            self.add_node(vertex)
            self.add_vertex_by_name(vertex)

    def _init_vertices_connections(self):
        """Add the graph edges and rewrite each vertex's inputs to refer to producing vertices.

        Also builds self._vertices_by_inp_key, mapping an input key to the vertex providing it.
        """
        self._vertices_by_inp_key = {net_input.name: net_input for net_input in self._net_input}
        for vi in self._vertices_by_name.values():
            # Connect every network input that this vertex consumes directly.
            for net_input in self._net_input:
                if net_input.name in vi.input:
                    self.add_edge(net_input, vi)

            vi_in_net_input = any(vi == net_input for net_input in self._net_input)
            for vj in self._vertices_by_name.values():
                for inp in vj.input:
                    # NOTE(review): the second clause does not depend on `inp`, so when it holds every
                    # input of vj is mapped to vi - confirm this is intended.
                    if inp in vi.output or (vi_in_net_input and vi.name in vj.input):
                        self.add_edge(vi, vj)
                        self._vertices_by_inp_key[inp] = vi

        for vertex in self._vertices_by_name.values():
            input_indices = []
            for inp in vertex.input:
                # Inputs without a known providing vertex are dropped from the rewritten list.
                if inp in self._vertices_by_inp_key:
                    input_vertex = self._vertices_by_inp_key[inp]
                    # Qualify with the tensor name when it differs from the producer's name or when the
                    # producer has several outputs; otherwise the producer name alone is used.
                    if inp != input_vertex.name or len(input_vertex.output) > 1:
                        input_indices.append(f"{input_vertex.name}{VERTEX_NAME_SEPARATOR}{inp}")
                    else:
                        input_indices.append(input_vertex.name)
            vertex.input = input_indices

    def _init_values_by_vertices(self):
        """Collect, per vertex name, the initializer arrays that the vertex consumes as inputs."""
        values_by_name = {x.name: numpy_helper.to_array(x) for x in self._values}
        self._values_by_vertex_name = {}
        for name, vertex in self._vertices_by_name.items():
            vertex_values = {}
            self._values_by_vertex_name[name] = vertex_values
            for i, inp in enumerate(vertex.input):
                if inp in values_by_name:
                    key = self._get_initializer_key_from_op_index(vertex, i, inp)
                    vertex_values[key] = values_by_name[inp]

    def _init_tensor_shapes_by_vertices(self, tensor_shapes):
        """Record the shape of every network input and the output dims of every vertex.

        The dims of all outputs of a vertex that appear in `tensor_shapes` are appended to one
        flat list, in output order.
        """
        for net_input in self._net_input:
            self._tensor_shapes_by_name[net_input.name] = net_input.get_net_input_shape()
        tensors_by_output_index = {x.name: x.type.tensor_type.shape.dim for x in tensor_shapes}
        for vertex_name, vertex in self._vertices_by_name.items():
            self._tensor_shapes_by_name[vertex_name] = []
            for idx in vertex.output:
                if idx in tensors_by_output_index:
                    shapes = [int(x.dim_value) for x in tensors_by_output_index[idx]]
                    self._tensor_shapes_by_name[vertex_name].extend(shapes)

    @staticmethod
    def _get_initializer_key_from_op_index(vertex, index, inp):
        """Return the key under which an initializer consumed by `vertex` is stored.

        For ops in OPS_WITH_WEIGHTS the key is "kernel" for MatMul, otherwise the entry of
        OPS_WITH_WEIGHTS_PARAMS_ORDER at the input position `index`. Any other op keeps the
        raw input name `inp`.
        """
        if vertex.op in OPS_WITH_WEIGHTS:
            if vertex.op == "MatMul":
                return "kernel"

            return OPS_WITH_WEIGHTS_PARAMS_ORDER[index]
        return inp

    @property
    def vertices_by_name(self):
        """Dict of vertex name -> vertex for the vertices created from raw nodes."""
        return self._vertices_by_name

    @property
    def values_by_vertex_name(self):
        """Dict of vertex name -> {key: initializer array} for the vertex's initializer inputs."""
        return self._values_by_vertex_name

    @property
    def vertices_by_inp_key(self):
        """Dict of input key -> vertex providing that input."""
        return self._vertices_by_inp_key

    @property
    def net_input(self):
        """List of the network input vertices."""
        return self._net_input

    @property
    def net_input_format(self):
        """The net input format given at construction, or an empty dict when none was given."""
        return self._net_input_format

    @property
    def tensor_shapes_by_vertex_name(self):
        """Dict of vertex (or network input) name -> recorded shape."""
        return self._tensor_shapes_by_name

    @property
    def output_shapes(self):
        """The output shapes given at construction."""
        return self._output_shapes

    @property
    def opset_version(self):
        """The opset version given at construction."""
        return self._opset_version

    @property
    def null_transposes(self):
        """List of null transposes; starts empty and is not filled inside this class."""
        return self._null_transposes
