import numpy as np
from tflite import (
    ActivationFunctionType,
    AddOptions,
    ConcatenationOptions,
    Conv2DOptions,
    DepthToSpaceOptions,
    DepthwiseConv2DOptions,
    DivOptions,
    FullyConnectedOptions,
    L2NormOptions,
    LeakyReluOptions,
    MulOptions,
    PackOptions,
    Padding,
    Pool2DOptions,
    ReducerOptions,
    ResizeBilinearOptions,
    ResizeNearestNeighborOptions,
    SpaceToDepthOptions,
    StridedSliceOptions,
    SubOptions,
)
from tflite.utils import BUILTIN_OPCODE2NAME

from hailo_sdk_client.model_translator.edge_nn_translator import INPUT_OP
from hailo_sdk_client.model_translator.exceptions import (
    CantFindBiasedDelatError,
    CantFindGraphStartError,
    CantFindSwishBetaError,
    UnsupportedActivationLayerError,
    UnsupportedConcatLayerError,
    UnsupportedConvLayerError,
    UnsupportedEWLayerError,
    UnsupportedFeatureSplitterError,
    UnsupportedFeatureSplitterLayerError,
    UnsupportedNormalizationLayerError,
    UnsupportedPaddingError,
    UnsupportedQuantizedWeightsError,
    UnsupportedReduceMeanLayerError,
    UnsupportedSliceLayerError,
    UnsupportedSquareLayerError,
)
from hailo_sdk_client.model_translator.graph_lookup import (
    BwdChainNode,
    FwdChainNode,
    get_all_nodes_from_possible_chains,
    get_all_nodes_in_chain,
    get_node_from_possible_chains,
    look_and_validate,
    look_for_node,
)
from hailo_sdk_client.model_translator.nn_graph import NNGraph, NNGraphNode
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, PaddingType, ResizeBilinearPixelsMode, ResizeMethod

# Operator-name groups. Every entry is an op-name string of the kind stored in
# `TFLiteGraphNode.op`; the node accessors test membership with `self.op in <GROUP>`.
OUTPUT_OPS = ["output"]
POOL_OPS = ["AVERAGE_POOL_2D", "MAX_POOL_2D", "MEAN"]
CONCAT_OPS = ["CONCATENATION"]
CONV2D_OPS = ["CONV_2D", "DEPTHWISE_CONV_2D", "TRANSPOSE_CONV"]
DENSE_OPS = ["FULLY_CONNECTED"]
L2_OPS = ["L2_NORMALIZATION"]
ACTIVATION_OPS = [
    "RELU",
    "RELU6",
    "TANH",
    "EXP",
    "PRELU",
    "LESS",
    "SQRT",
    "LEAKY_RELU",
    "ABS",
    "ELU",
    "HARD_SWISH",
    "LOGISTIC",
    "LOG",
    "GREATER",
    "GELU",
    "RSQRT",
]
DEPTH_TO_SPACE_OPS = ["DEPTH_TO_SPACE"]
SPACE_TO_DEPTH_OPS = ["SPACE_TO_DEPTH"]
SHUFFLE_OPS = ["RESHAPE", "TRANSPOSE"]
RESIZE_OPS = ["RESIZE_BILINEAR", "RESIZE_NEAREST_NEIGHBOR"]
SOFTMAX_OPS = ["SOFTMAX"]
ARGMAX_OPS = ["ARG_MAX"]
PAD_OPS = ["PAD", "PADV2"]
SHAPE_OPS = ["SHAPE", "GATHER", "EXPAND_DIMS", "TILE", "PACK", "UNPACK"]
SLICE_OPS = ["SLICE", "STRIDED_SLICE"]
SPLIT_OPS = ["SPLIT", "SPLIT_V"]
POW_OPS = ["POW", "SQUARE"]
MATH_OPS = ["CAST", "FLOOR"]
REDUCE_MAX_OPS = ["REDUCE_MAX"]
REDUCE_MIN_OPS = ["REDUCE_MIN"]
SUM_OPS = ["SUM"]
DEQUANTIZE_OPS = ["DEQUANTIZE"]
ADD_OPS = ["ADD"]
ADD_N_OPS = ["ADD_N"]
SUB_OPS = ["SUB", "NEG"]
MUL_OPS = ["MUL"]
DIV_OPS = ["DIV"]
MIN_OPS = ["MINIMUM"]
MAX_OPS = ["MAXIMUM"]
# Concatenation of the arithmetic and min/max groups above.
EW_OPS = ADD_OPS + SUB_OPS + MUL_OPS + DIV_OPS + MAX_OPS + MIN_OPS
EQUAL_OPS = ["EQUAL"]
# NOTE(review): "Where" is the only mixed-case name among the builtin-style groups - confirm it is intended.
LOGICAL_OPS = ["Where", *EQUAL_OPS]
TILE_OPS = ["TILE"]
PACK_OPS = ["PACK"]
SKIP_OPS = ["SHAPE", "GATHER"]
PRE_LAYER_OPS = ["RESHAPE"]
# Same names that `get_fused_activation_op` can return (besides "NONE").
SUPPORTED_FUSED_ACTIVATIONS = ["RELU", "RELU6", "TANH"]
# Ops that `is_null_operation` may classify as no-ops, depending on per-op checks.
OPTIONAL_NULL_OPS = [*MATH_OPS, "TRANSPOSE", "UNPACK", "RESHAPE", "POW", "SQUEEZE", "EXPAND_DIMS", "TILE"]
# Substrings searched for in a CUSTOM op's serialized options to recognise a sign op.
CUSTOM_SIGN_OPS = ["Sign", "SIGN", "FLEX_SIGN"]
# Ops that `is_normalization` considers as normalization candidates.
NORMALIZATION_OPS = ADD_OPS + SUB_OPS + MUL_OPS + DIV_OPS + L2_OPS
# Flattened list of the groups below. It is a plain list, so names that belong to several
# groups (e.g. "RESHAPE" via SHUFFLE_OPS and OPTIONAL_NULL_OPS) appear more than once.
SUPPORTED_OPS_UNION = [
    INPUT_OP,
    *POOL_OPS,
    *CONCAT_OPS,
    *CONV2D_OPS,
    *DENSE_OPS,
    *ACTIVATION_OPS,
    *RESIZE_OPS,
    *SHUFFLE_OPS,
    *SOFTMAX_OPS,
    *ARGMAX_OPS,
    *PAD_OPS,
    *SHAPE_OPS,
    *SLICE_OPS,
    *SPLIT_OPS,
    *POW_OPS,
    *MATH_OPS,
    *REDUCE_MAX_OPS,
    *NORMALIZATION_OPS,
    *MIN_OPS,
    *MAX_OPS,
    *SUM_OPS,
    *SKIP_OPS,
    *DEQUANTIZE_OPS,
    *DEPTH_TO_SPACE_OPS,
    *SPACE_TO_DEPTH_OPS,
    *ADD_N_OPS,
    *CUSTOM_SIGN_OPS,
    *EQUAL_OPS,
    *REDUCE_MIN_OPS,
    *OPTIONAL_NULL_OPS,
]

# Positional layout of each operator's inputs. Accessors look a slot up by label, e.g.
# `CONV_INPUT_ORDER.index("W")`, and use the result to index `self.input`.
CONV_INPUT_ORDER = ["X", "W", "B"]
DEPTHWISE_CONV_INPUT_ORDER = ["X", "W", "B"]
DECONV_INPUT_ORDER = ["OUTPUT_SHAPE", "W", "X", "B"]
DENSE_INPUT_ORDER = ["X", "W", "B"]
SPLIT_INPUT_ORDER = ["AXIS", "X"]
SPLIT_V_INPUT_ORDER = ["X", "SIZE", "AXIS"]
POW_INPUT_ORDER = ["X", "Y"]
RESIZE_INPUT_ORDER = ["X", "SIZES"]
REDUCE_INPUT_ORDER = ["X", "AXIS"]
TRANSPOSE_INPUT_ORDER = ["X", "PERM"]
STRIDED_SLICE_INPUT_ORDER = ["X", "BEGIN", "END", "STRIDES"]
SLICE_INPUT_ORDER = ["X", "BEGIN", "SIZE"]
MINIMUM_INPUT_ORDER = ["X", "CONST"]
MAXIMUM_INPUT_ORDER = ["X", "CONST"]

# Opcode number -> op name, consulted by `TFLiteGraphNode.__init__` only when the opcode
# is missing from BUILTIN_OPCODE2NAME.
ALTERNATIVE_CUSTOM_OPCODES = {158: "SIGN"}


class TFLiteGraphNode(NNGraphNode):
    def __init__(self, node_proto, graph, is_input_vertex=False):
        """Build a graph vertex from a TFLite tensor (input vertex) or operator.

        Args:
            node_proto: flatbuffer object. For an input vertex its ``Name`` and
                ``ShapeAsNumpy`` accessors are read; otherwise its ``InputsAsNumpy``,
                ``OutputsAsNumpy``, ``OpcodeIndex`` and ``CustomOptionsAsNumpy``.
            graph: owning graph; ``_model_tensors``, ``_model_tensors_names`` and
                ``_raw_proto`` are read from it.
            is_input_vertex: whether this vertex represents a model input tensor.
        """
        super().__init__(node_proto, graph)
        self._attrs_dict = {}

        def shape_as_list(tensor):
            # ShapeAsNumpy may hand back a plain int instead of an array (the operator
            # branch always special-cased this); such tensors are given shape [1].
            shape = tensor.ShapeAsNumpy()
            return [1] if isinstance(shape, int) else shape.tolist()

        if is_input_vertex:
            self.op = INPUT_OP
            self.name = node_proto.Name().decode("utf-8")
            # Previously `.tolist()` was called unguarded here and crashed on an int shape.
            tensor_shapes = [shape_as_list(node_proto)]
            self.input_shapes = tensor_shapes
            self.output_shapes = tensor_shapes
            self.output_tensors_indices = []
            return

        tensor_names = graph._model_tensors_names
        # The vertex is named after its first output tensor.
        self.name = tensor_names[node_proto.OutputsAsNumpy()[0]]
        opcode = graph._raw_proto.OperatorCodes(node_proto.OpcodeIndex()).BuiltinCode()
        self.op = "CUSTOM"
        if opcode in BUILTIN_OPCODE2NAME:
            self.op = BUILTIN_OPCODE2NAME[opcode]
        elif opcode in ALTERNATIVE_CUSTOM_OPCODES:
            self.op = ALTERNATIVE_CUSTOM_OPCODES[opcode]

        # handle (known) custom ops. currently support SIGN only.
        if self.op == "CUSTOM":
            raw_options = node_proto.CustomOptionsAsNumpy()
            # Guard against the accessor returning a plain int (no options stored),
            # which has no `tobytes`; the op then simply stays "CUSTOM".
            if not isinstance(raw_options, int):
                custom_options = str(raw_options.tobytes())
                if any(x in custom_options for x in CUSTOM_SIGN_OPS):
                    self.op = "SIGN"

        # -1 means input not exist for optional inputs in MLIR
        input_indices = [i for i in node_proto.InputsAsNumpy() if i > -1]
        output_indices = [i for i in node_proto.OutputsAsNumpy() if i > -1]
        self.input = [tensor_names[i] for i in input_indices]
        self.output = [tensor_names[i] for i in output_indices]

        input_tensors = [graph._model_tensors[i] for i in input_indices]
        out_tensors = [graph._model_tensors[i] for i in output_indices]
        self.input_shapes = [shape_as_list(tensor) for tensor in input_tensors]
        self.output_shapes = [shape_as_list(tensor) for tensor in out_tensors]
        self.output_tensors_indices = [graph._model_tensors.index(tensor) for tensor in out_tensors]

    def get_original_info_to_json(self):
        return str(self._attrs_dict)

    def get_vertex_successors_io_indices(self):
        succs_by_inp_idx = {}
        for succ in list(self._graph.successors(self)):
            input_index = next(x for x in succ.input if self.name in x)
            input_index = f"{input_index}:0" if ":" not in input_index else input_index
            if input_index in succs_by_inp_idx:
                succs_by_inp_idx[input_index].append(succ)
            else:
                succs_by_inp_idx[input_index] = [succ]

        def sort_func(x):
            outputs = [f"{out}:0" if ":" not in out else out for out in self.output]
            return outputs.index(x)

        return {i: succs_by_inp_idx[x] for i, x in enumerate(sorted(succs_by_inp_idx.keys(), key=sort_func))}

    def get_conv_info(self):
        """Collect attributes, kernel and bias of a convolution-like vertex.

        Handles ``CONV_2D``, ``TRANSPOSE_CONV`` and ``DEPTHWISE_CONV_2D``.

        Returns:
            tuple ``(attrs, kernel, bias)``: ``attrs`` is ``self._attrs_dict`` (updated in
            place); ``kernel`` and ``bias`` are the constant tensors found, or ``None``.

        Raises:
            UnsupportedConvLayerError: if ``self.op`` is none of the ops above.
            UnsupportedQuantizedWeightsError: propagated from ``verify_quantized_weights``.
        """
        raw_options = self._info.BuiltinOptions()
        op_name = self.op
        attrs = self._attrs_dict

        if op_name == "DEPTHWISE_CONV_2D":
            options = DepthwiseConv2DOptions()
            options.Init(raw_options.Bytes, raw_options.Pos)
            attrs["depth_multiplier"] = options.DepthMultiplier()
            slots = DEPTHWISE_CONV_INPUT_ORDER
        elif op_name in ("CONV_2D", "TRANSPOSE_CONV"):
            options = Conv2DOptions()
            options.Init(raw_options.Bytes, raw_options.Pos)
            slots = DECONV_INPUT_ORDER if op_name == "TRANSPOSE_CONV" else CONV_INPUT_ORDER
        else:
            raise UnsupportedConvLayerError(f"Unexpected convolutional layer at {self.name}, op={self.op}")

        attrs["strides"] = [1, int(options.StrideH()), int(options.StrideW()), 1]
        attrs["fused_activation"] = self.get_fused_activation_op(options.FusedActivationFunction())

        if op_name == "TRANSPOSE_CONV":
            attrs["dilations"] = [1, 1, 1, 1]
            attrs["padding"] = PaddingType.deconv
        else:
            attrs["dilations"] = [1, int(options.DilationHFactor()), int(options.DilationWFactor()), 1]
            padding_mode = options.Padding()
            if padding_mode == Padding.VALID:
                attrs["padding"] = PaddingType.valid
            elif padding_mode == Padding.SAME:
                attrs["padding"] = PaddingType.same_tensorflow

        kernel_slot = slots.index("W")
        bias_slot = slots.index("B")
        const_inputs = self.graph.values_by_vertex_name[self.name]
        kernel = const_inputs.get(self.input[kernel_slot], None)
        # The bias input is optional: only read it when the vertex has that many inputs.
        bias = const_inputs.get(self.input[bias_slot], None) if len(self.input) > bias_slot else None

        def from_dequantize(tensor_name):
            # Constant kept under the dequantize vertex, keyed by the name preceding "_dequantize".
            stored = self.graph.values_by_vertex_name[tensor_name]
            return stored.get(tensor_name.partition("_dequantize")[0], None)

        # fallback, weights are behind dequantize layer casting to float dtype (permitted earlier)
        dequantize_count = sum(1 for pred in self.graph.predecessors(self) if pred.op == "DEQUANTIZE")
        if dequantize_count > 0:
            kernel = from_dequantize(self.input[kernel_slot])
            if dequantize_count > 1:
                bias = from_dequantize(self.input[bias_slot])

        self.verify_quantized_weights([kernel, bias])
        return attrs, kernel, bias

    def get_dense_info(self):
        """Collect attributes, kernel and bias of a fully-connected vertex.

        Returns:
            tuple ``(attrs, kernel, bias)``: ``attrs`` is ``self._attrs_dict`` with
            ``"fused_activation"`` set; ``kernel`` and ``bias`` are the constant tensors
            found, or ``None``.

        Raises:
            UnsupportedQuantizedWeightsError: propagated from ``verify_quantized_weights``.
        """
        raw_options = self._info.BuiltinOptions()
        options = FullyConnectedOptions()
        options.Init(raw_options.Bytes, raw_options.Pos)
        self._attrs_dict["fused_activation"] = self.get_fused_activation_op(options.FusedActivationFunction())

        kernel_slot = DENSE_INPUT_ORDER.index("W")
        bias_slot = DENSE_INPUT_ORDER.index("B")
        const_inputs = self.graph.values_by_vertex_name[self.name]
        kernel = const_inputs.get(self.input[kernel_slot], None)
        # A bias is only looked up when the vertex carries the full X/W/B input list.
        has_bias_input = len(self.input) == len(DENSE_INPUT_ORDER)
        bias = const_inputs.get(self.input[bias_slot], None) if has_bias_input else None

        def from_dequantize(tensor_name):
            # Constant kept under the dequantize vertex, keyed by the name preceding "_dequantize".
            stored = self.graph.values_by_vertex_name[tensor_name]
            return stored.get(tensor_name.partition("_dequantize")[0], None)

        # fallback, weights are behind dequantize layer casting to float dtype (permitted earlier)
        dequantize_count = sum(1 for pred in self.graph.predecessors(self) if pred.op == "DEQUANTIZE")
        if dequantize_count > 0:
            kernel = from_dequantize(self.input[kernel_slot])
            if dequantize_count > 1:
                bias = from_dequantize(self.input[bias_slot])

        self.verify_quantized_weights([kernel, bias])
        return self._attrs_dict, kernel, bias

    @staticmethod
    def verify_quantized_weights(variables):
        int_dtypes = ["int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"]
        np_int_dtypes = [np.dtype(x) for x in int_dtypes]
        variables_dtypes = {var.dtype for var in variables if var is not None}
        int_dtypes_cond = any(var_dtype in np_int_dtypes for var_dtype in variables_dtypes)
        if int_dtypes_cond:
            error_msg = (
                f"Model weights are quantized to {list(variables_dtypes)}, which is currently unsupported. "
                "Please use a model with float32/16 only, mixed precision is not supported as well."
            )
            raise UnsupportedQuantizedWeightsError(error_msg)

        if any(var_dtype == np.float16 for var_dtype in variables_dtypes) and any(
            var_dtype == np.float32 for var_dtype in variables_dtypes
        ):
            error_msg = (
                "Model weights are mixed precision (float16/float32), which is currently unsupported. "
                "Please use a model with float32/16 only."
            )
            raise UnsupportedQuantizedWeightsError(error_msg)

    def get_fused_activation_op(self, act):
        """Translate a fused-activation code into its op-name string.

        Args:
            act: an ``ActivationFunctionType`` value, or ``None``.

        Returns:
            ``"NONE"``, ``"RELU"``, ``"RELU6"`` or ``"TANH"``.

        Raises:
            UnsupportedActivationLayerError: for any other activation code.
        """
        if act is None:
            return "NONE"

        known_activations = (
            (ActivationFunctionType.NONE, "NONE"),
            (ActivationFunctionType.RELU, "RELU"),
            (ActivationFunctionType.RELU6, "RELU6"),
            (ActivationFunctionType.TANH, "TANH"),
        )
        for code, label in known_activations:
            if act == code:
                return label

        raise UnsupportedActivationLayerError(f"Unsupported fused activation type {act} found at layer {self.name}")

    def get_pooling_info(self):
        """Read strides, kernel shape, fused activation and padding of a 2D pooling vertex.

        Returns:
            ``self._attrs_dict`` after the keys ``"strides"``, ``"kernel_shape"``,
            ``"fused_activation"`` and (for VALID/SAME padding) ``"padding"`` were set.
        """
        raw_options = self._info.BuiltinOptions()
        options = Pool2DOptions()
        options.Init(raw_options.Bytes, raw_options.Pos)

        attrs = self._attrs_dict
        attrs["strides"] = [1, int(options.StrideH()), int(options.StrideW()), 1]
        attrs["kernel_shape"] = [1, int(options.FilterHeight()), int(options.FilterWidth()), 1]
        attrs["fused_activation"] = self.get_fused_activation_op(options.FusedActivationFunction())

        padding_mode = options.Padding()
        if padding_mode == Padding.VALID:
            attrs["padding"] = PaddingType.valid
        elif padding_mode == Padding.SAME:
            attrs["padding"] = PaddingType.same_tensorflow

        return attrs

    def get_avgpool_reduce_mean_info(self):
        _, axis = self.get_reduce_info()
        input_shape = self.input_shapes[0]
        if np.array_equal(axis, np.array(1)):
            kernel_shape = [1, input_shape[1], 1, 1]
            stride = [1, input_shape[1], 1, 1]
        elif np.array_equal(axis, np.array(2)):
            kernel_shape = [1, 1, input_shape[2], 1]
            stride = [1, 1, input_shape[2], 1]
        else:
            err_msg = (
                f"Reduce mean layer {self.name} has unsupported axis {axis} "
                "(must be over one spatial dimension only)."
            )
            raise UnsupportedReduceMeanLayerError(err_msg)
        return kernel_shape, stride

    def is_avgpool_reduce_mean(self):
        _, axis = self.get_reduce_info()
        input_shape = self.input_shapes[0]

        # axes is only width and height is not 1
        if np.array_equal(axis, np.array(2)) and input_shape[1] > 1:
            return True

        # axes is only height and width is not 1
        if np.array_equal(axis, np.array(1)) and input_shape[2] > 1:
            return True

        return False

    def validate_reduce_mean_info(self):
        keep_dims, axis = self.get_reduce_info()
        input_shape = self.input_shapes[0]

        # axes are just on width, but height is 1
        valid_height = np.array_equal(axis, np.array([2])) and input_shape[1] == 1
        # axes are just on height, but width is 1
        valid_width = np.array_equal(axis, np.array([1])) and input_shape[2] == 1

        unsupported_axis = not np.array_equal(axis, np.array([1, 2])) and not valid_height and not valid_width
        # keep_dims=False is ok if the correct axis are used
        unsupported_keep_dims = keep_dims is False and unsupported_axis

        if unsupported_axis or unsupported_keep_dims:
            err_msg = f"Reduce mean layer {self.name} has "
            if unsupported_axis:
                err_msg += f"unsupported axis {axis} (must be over spatial dimensions only), "
            if unsupported_keep_dims:
                if unsupported_axis:
                    err_msg += "and "
                err_msg += "unsupported keep_dims=False, "
            err_msg += "must be equivalent to global average pool."
            raise UnsupportedReduceMeanLayerError(err_msg)

    def get_prelu_slope(self):
        input_tensors = self.graph.values_by_vertex_name[self.name]
        prelu_slope = input_tensors.get(self.input[1], None)

        # fallback, weights are behind dequantize layer casting to float dtype (permitted earlier)
        dequantize_preds = [x for x in self.graph.predecessors(self) if x.op == "DEQUANTIZE"]
        if len(dequantize_preds) > 0:
            slope_name = self.input[1]
            slope_dequantize = self.graph.values_by_vertex_name[slope_name]
            prelu_slope = slope_dequantize.get(slope_name.partition("_dequantize")[0], None)

        if prelu_slope is not None:
            return prelu_slope.flatten()
        return None

    def get_leaky_alpha(self):
        """Read the alpha of a leaky-ReLU vertex, store it under ``"alpha"`` and return it."""
        raw_options = self._info.BuiltinOptions()
        options = LeakyReluOptions()
        options.Init(raw_options.Bytes, raw_options.Pos)
        alpha = options.Alpha()
        self._attrs_dict["alpha"] = alpha
        return alpha

    def is_ew_mult(self):
        """Tell whether this is a multiplication vertex that qualifies as element-wise."""
        return self.op in MUL_OPS and self.is_ew_op()

    def is_ew_add(self):
        """Tell whether this is an addition vertex that is element-wise or adds a suitable constant."""
        if self.op not in ADD_OPS:
            return False
        return self.is_ew_op() or self.is_ew_add_with_const_input()

    def is_ew_sub(self):
        if self.op == "SUB":
            return self.is_ew_op()
        return False

    def is_ew_div(self):
        """Tell whether this is a division vertex that qualifies as element-wise."""
        return self.op in DIV_OPS and self.is_ew_op()

    def is_ew_max(self):
        """Tell whether this is a maximum vertex that qualifies as element-wise."""
        return self.op in MAX_OPS and self.is_ew_op()

    def is_ew_min(self):
        """Tell whether this is a minimum vertex that qualifies as element-wise."""
        return self.op in MIN_OPS and self.is_ew_op()

    def is_ew_op(self):
        if self._has_dynamic_shape_const_input():
            return False

        return len(list(self.graph.predecessors(self))) == 2

    def is_ew_add_with_const_input(self):
        """Tell whether this addition adds a rank-4 constant compatible with the output shape.

        The constant is the second input. It qualifies when, ignoring the first axis, its
        shape equals the output shape, or when it has extent 1 on axis 1 or on axis 2 and
        matches the output on the remaining axes.
        """
        output_shapes = self.get_output_shapes()
        if not (self.op in ADD_OPS and output_shapes):
            return False

        const_inputs = self.graph.values_by_vertex_name[self.name]
        const_name = self.input[1]
        if const_inputs.get(const_name) is None:
            return False

        const_shape = list(np.shape(const_inputs[const_name]))
        if len(const_shape) != 4:
            return False

        out_shape = output_shapes[0]
        if const_shape[1:] == out_shape[1:]:
            return True
        if const_shape[1] == 1 and const_shape[2:] == out_shape[2:]:
            return True
        return const_shape[2] == 1 and const_shape[1] == out_shape[1] and const_shape[3] == out_shape[3]

    def _has_dynamic_shape_const_input(self):
        """Tell whether a MUL <- FILL <- SHAPE chain feeds this vertex from a common stem.

        Returns ``True`` only when the backward chain is found and ``has_common_stem``
        accepts the ``SHAPE`` vertex at its end.
        """
        backward_chain = [BwdChainNode(op=op_name) for op_name in ("MUL", "FILL", "SHAPE")]
        shape_vertex = look_for_node(self._graph, self, backward_chain)
        if shape_vertex is None:
            return False
        return bool(self.has_common_stem(shape_vertex))

    def _get_dynamic_shape_const_input(self):
        """Materialise the constant produced by a MUL <- FILL <- SHAPE chain.

        Returns:
            tuple ``(values, index)``: an array shaped like the MUL vertex's first output,
            filled with that vertex's normalization values, and the position of the MUL
            vertex within ``self.input``.
        """
        backward_chain = [BwdChainNode(op=op_name) for op_name in ("MUL", "FILL", "SHAPE")]
        mul_vertex, _fill, _shape = get_all_nodes_in_chain(self._graph, self, backward_chain)

        mul_position = self.input.index(mul_vertex.name)
        fill_values, _ = mul_vertex.get_normalization_node_values()
        filled = np.full(mul_vertex.output_shapes[0], fill_values)
        return filled, mul_position

    def is_mul_by_2_ew_add(self):
        """Tell whether this addition adds its single predecessor's tensor to itself."""
        if self.op not in ADD_OPS:
            return False

        preds = list(self._graph.predecessors(self))
        if len(preds) != 1 or len(self.input) != 2:
            return False

        first_input, second_input = self.input
        return first_input == second_input and preds[0].name == first_input

    def is_flatten_reshape(self):
        output_shape = self.output_shapes[0]
        input_shape = self.input_shapes[0]
        return (
            len(output_shape) == 2
            and len(input_shape) == 4
            and output_shape[1] == input_shape[1] * input_shape[2] * input_shape[3]
        )

    def is_transpose_flatten_reshape(self):
        """Tell whether this null transpose is followed by a flattening ``RESHAPE``."""
        if self.is_null_transpose():
            following_reshape = look_for_node(self._graph, self, [FwdChainNode(op="RESHAPE")])
            if following_reshape is not None:
                return following_reshape.is_flatten_reshape()
        return False

    def is_null_unpack(self):
        """Tell whether this is an ``UNPACK`` vertex followed by an ``EXPAND_DIMS`` vertex."""
        if self.op == "UNPACK":
            follower = look_for_node(self._graph, self, [FwdChainNode(op="EXPAND_DIMS")])
            return follower is not None
        return False

    def is_null_reshape(self):
        if self.op != "RESHAPE":
            return False

        if self._is_flat_to_frames_reshape_after_reduce_op():
            return True

        if self.is_rank4_to_rank3_reshape():
            return True

        return self.input_shapes[0] == self.output_shapes[0]

    def _is_flat_to_frames_reshape_after_reduce_op(self):
        """Tell whether this flat-to-frames reshape directly follows a reduce-max, sum or mean vertex."""
        if self.is_flat_to_frames_reshape():
            first_pred = next(iter(self.graph.predecessors(self)))
            return first_pred.op in [*REDUCE_MAX_OPS, *SUM_OPS, "MEAN"]
        return False

    def is_rank4_to_rank3_reshape(self):
        expected_shape = self.output_shapes[0].copy()
        expected_shape.insert(1, 1)
        return self.input_shapes[0] == expected_shape

    def is_null_pow(self):
        """Tell whether this power vertex raises its input to the power of 1.

        The exponent is the constant stored for the ``"Y"`` input. The result is always a
        plain ``bool``: previously the raw ``array == 1`` comparison was returned, which
        raised an ambiguous-truth ``ValueError`` in callers for multi-element exponents.

        Returns:
            ``True`` when an exponent constant exists, is non-empty and every element
            equals 1; ``False`` otherwise (including when no constant is stored).
        """
        input_tensors = self.graph.values_by_vertex_name[self.name]
        exponent = input_tensors.get(self.input[POW_INPUT_ORDER.index("Y")])
        if exponent is None:
            return False

        exponent = np.asarray(exponent)
        # np.all of an empty array is True, so emptiness has to be excluded explicitly.
        return bool(exponent.size > 0 and np.all(exponent == 1))

    def is_null_tile(self):
        multipliers, _ = self.get_tile_multipliers_info()
        if multipliers is None:
            return False
        return all(multipliers == 1)

    def is_null_expand_dims(self):
        expected_shape = self.input_shapes[0].copy()
        expected_shape.insert(1, 1)
        return self.output_shapes[0] == expected_shape

    def is_null_operation(self):
        """Tell whether this vertex can be treated as a no-op.

        Only ops listed in ``OPTIONAL_NULL_OPS`` can qualify: ops in ``MATH_OPS`` always do,
        the others only when their op-specific check succeeds.
        """
        op_name = self.op
        if op_name not in OPTIONAL_NULL_OPS:
            return False
        if op_name in MATH_OPS:
            return True

        # Per-op predicate deciding whether the vertex is a no-op.
        null_checks = {
            "TRANSPOSE": self.is_transpose_flatten_reshape,
            "UNPACK": self.is_null_unpack,
            "RESHAPE": self.is_null_reshape,
            "POW": self.is_null_pow,
            "SQUEEZE": lambda: self.is_flatten_reshape() or self.is_rank4_to_rank3_reshape(),
            "TILE": self.is_null_tile,
            "EXPAND_DIMS": self.is_null_expand_dims,
        }
        check = null_checks.get(op_name)
        return check is not None and bool(check())

    def is_normalization(self):
        """Tell whether this vertex should be handled as a normalization.

        It must be one of ``NORMALIZATION_OPS`` and be neither an element-wise op, a
        square multiplication, nor an inverse-positive activation.
        """
        if self.op not in NORMALIZATION_OPS:
            return False

        excluded = self.is_ew_op() or self.is_square_mul() or self.is_inv_pos_activation()
        return not excluded

    def get_normalization_info(self):
        """Derive mean/std lists for a normalization that starts at this vertex.

        ``self._attrs_dict`` is replaced by a fresh dict holding ``"fused_activation"`` and
        ``"fused_activation_vertex"``; both are re-pointed to the second vertex when a
        following Mul/Div (after Add/Sub) or Add/Sub (after Mul/Div) is folded in.

        Returns:
            tuple ``(attrs, mean, std, consumed)``: ``attrs`` is ``self._attrs_dict``,
            ``mean`` and ``std`` are lists of values, and ``consumed`` is a list with the
            extra vertex folded into this normalization (empty when there is none).

        Raises:
            UnsupportedNormalizationLayerError: if the first output shape is empty, if no
                constant values are found for this vertex, or if ``self.op`` is not an
                Add/Sub/Mul/Div (or NEG) op.
        """
        # Possible cases (always assume start from the first encountered operator):
        # 1. Add/Sub -> Mul/Div: normalization of the form (x-Mean(x))/(Std(x))
        # 2. Mul/Div -> Add/Sub: equivalent normalization of the form (x-Mean(x)*Std(x))/(Std(x))
        # 3. Add/Sub: normalization with Std(x)=1 (private case Neg has Std(x)=-1)
        # 4. Mul/Div: normalization with Mean(x)=0
        self._attrs_dict = {
            "fused_activation": self.get_ew_op_fused_activation(),
            "fused_activation_vertex": self,
        }
        if not self.output_shapes[0]:
            raise UnsupportedNormalizationLayerError(
                f"Could not find output shapes in node {self.name}",
            )

        if self.op == "NEG":
            # can also be seen as case #3
            return self._attrs_dict, [0.0], [-1.0], []

        # input_idx is the position in self.input at which the constant was found.
        values, input_idx = self.get_normalization_node_values()
        if values is None:
            raise UnsupportedNormalizationLayerError(
                f"Could not find normalization values in node {self.name}",
            )
        if self.op in ADD_OPS + SUB_OPS:
            # either add -mean or subtract mean
            mean = (np.negative(values) if self.op in ADD_OPS else values).tolist()
            mean = mean if isinstance(mean, list) else [mean]
            if len(mean) == 1:
                # A single value is repeated once per entry of the last output axis.
                mean = mean * self.output_shapes[0][-1]

            # covering case #1
            std_node = None

            # A following Mul/Div is only searched for when this vertex has no fused activation.
            if self._attrs_dict["fused_activation"] == "NONE":
                mul_node = look_for_node(self._graph, self, [FwdChainNode(op="MUL")])
                div_node = look_for_node(self._graph, self, [FwdChainNode(op="DIV")])
                if mul_node and mul_node.is_normalization():
                    std_node = mul_node
                elif div_node and div_node.is_normalization():
                    std_node = div_node

            if std_node:
                self._attrs_dict["fused_activation"] = std_node.get_ew_op_fused_activation()
                self._attrs_dict["fused_activation_vertex"] = std_node
                # NOTE(review): the values returned here are not checked for None - confirm
                # that a vertex accepted by is_normalization() always has constant values.
                std_params, _ = std_node.get_normalization_node_values()
                # either multiply by 1/std or divide by std
                std = (np.reciprocal(std_params) if std_node.op in MUL_OPS else std_params).tolist()
                std = std if isinstance(std, list) else [std]
                if len(std) == 1:
                    std = std * std_node.output_shapes[0][-1]
            else:
                std = [1.0]

            if self.op in SUB_OPS and input_idx == 0:  # Sub(c, x)
                std = [-1 * x for x in std]
            return self._attrs_dict, mean, std, [std_node] if std_node else []

        elif self.op in MUL_OPS + DIV_OPS:
            # covering case #4
            std_params = values
            # NOTE(review): unreachable - `values` was already checked for None above.
            if std_params is None:
                raise UnsupportedNormalizationLayerError(
                    f"Could not find std values in normalization starting from node {self.name}",
                )

            # either multiply by 1/std or divide by std (except for a mul-by-const scenario)
            if self.op in MUL_OPS:
                std_params = np.reciprocal(std_params)
                # Promote a 0-d result to 1-d so the masked assignment below works.
                std_params = np.array([std_params]) if std_params.shape == () else std_params
                # Reciprocals that came out as +inf are replaced by 0.
                std_params[std_params == np.inf] = 0.0
            std = std_params.tolist()
            std = std if isinstance(std, list) else [std]
            if len(std) == 1:
                std = std * self.output_shapes[0][-1]

            # covering case #2: multiply mean by std
            mean_node = None
            if self._attrs_dict["fused_activation"] == "NONE":
                sub_node = look_for_node(self._graph, self, [FwdChainNode(op="SUB")])
                add_node = look_for_node(self._graph, self, [FwdChainNode(op="ADD")])
                if sub_node and sub_node.is_normalization():
                    mean_node = sub_node
                elif add_node and add_node.is_normalization():
                    mean_node = add_node

            # In case of Div(c, x) we cant parse sub \ add as the same normalization
            if mean_node and input_idx == 1:
                self._attrs_dict["fused_activation"] = mean_node.get_ew_op_fused_activation()
                self._attrs_dict["fused_activation_vertex"] = mean_node
                mean_params, mean_input_idx = mean_node.get_normalization_node_values()
                # either add -mean or subtract mean
                mean = (np.negative(mean_params) if mean_node.op in ADD_OPS else mean_params).tolist()
                mean = mean if isinstance(mean, list) else [mean]
                if len(mean) == 1:
                    mean = mean * mean_node.output_shapes[0][-1]
                mean = [x * y for x, y in zip(mean, std)]
                if mean_node.op in SUB_OPS and mean_input_idx == 0:  # Sub(c, x) after Div(x, d)
                    std = [-1 * x for x in std]
            else:
                mean = [0.0]

            # mean_node is reported as consumed whenever one was found, also when the
            # `input_idx == 1` condition above did not hold.
            return self._attrs_dict, mean, std, [mean_node] if mean_node else []

        raise UnsupportedNormalizationLayerError(
            f"Could not find std/mean values in normalization starting from node {self.name}",
        )

    def get_normalization_node_values(self):
        input_tensors = self.graph.values_by_vertex_name[self.name]

        if len(self.input) > 1 and self.input[1] in input_tensors:
            values = input_tensors[self.input[1]]
            input_idx = 1
        elif self.input[0] in input_tensors:
            values = input_tensors[self.input[0]]
            input_idx = 0
        elif self._has_dynamic_shape_const_input():
            values, input_idx = self._get_dynamic_shape_const_input()
        else:
            values = None
            input_idx = -1

        values = values.squeeze() if values is not None else None
        return values, input_idx

    def get_padding_info(self):
        """Return the paddings of this vertex and the extra vertices consumed to get them.

        For concat ops the work is delegated to ``get_external_pad_info_from_concat``.
        Otherwise the constant second input is flattened and its entries 2..7 are
        returned, together with an empty list.

        Raises:
            UnsupportedPaddingError: if no constant is stored for the second input.
        """
        if self.op in CONCAT_OPS:
            return self.get_external_pad_info_from_concat()

        const_inputs = self.graph.values_by_vertex_name[self.name]
        raw_paddings = const_inputs.get(self.input[1], None)
        if raw_paddings is None:
            raise UnsupportedPaddingError(f"Could not find paddings in node {self.name}")

        flat_paddings = raw_paddings.astype(int).flatten()
        return flat_paddings[2:8].tolist(), []

    def get_concat_const_input_info(self):
        input_val = None
        input_name = None

        input_tensors = self.graph.values_by_vertex_name[self.name]
        valid_input = input_tensors and any(x in self.input for x in input_tensors)
        if valid_input and len(input_tensors) == 1:
            input_name = next(iter(input_tensors.keys()))
            input_val = np.squeeze(input_tensors[input_name], axis=0)

        return input_name, input_val

    def get_ew_const_input_info(self):
        input_val = None
        input_name = None

        input_tensors = self.graph.values_by_vertex_name[self.name]
        if input_tensors and len(self.input) == 2 and input_tensors.get(self.input[1]) is not None:
            input_name = self.input[1]
            input_val = input_tensors[input_name]

        return input_name, input_val

    def get_tile_multipliers_info(self):
        consumed_vertices = [self]
        input_tensors = self.graph.values_by_vertex_name[self.name]
        repeats = input_tensors.get(self.input[1], None)
        return repeats, consumed_vertices

    def get_concat_info(self):
        """Read the axis and fused activation of a concatenation vertex.

        Returns:
            ``self._attrs_dict`` with ``"axis"`` and ``"fused_activation"`` set.

        Raises:
            UnsupportedConcatLayerError: if the concat axis is 0.
        """
        raw_options = self._info.BuiltinOptions()
        options = ConcatenationOptions()
        options.Init(raw_options.Bytes, raw_options.Pos)

        axis = int(options.Axis())
        self._attrs_dict["axis"] = axis
        if axis == 0:
            raise UnsupportedConcatLayerError(f"Concat over the batch dimension is not supported in {self.name}.")

        self._attrs_dict["fused_activation"] = self.get_fused_activation_op(options.FusedActivationFunction())
        return self._attrs_dict

    def is_external_pad_concat(self):
        input_tensors = self.graph.values_by_vertex_name[self.name]
        if len(self.input) == 3 and self.input[0] == self.input[2] and len(input_tensors) == 1:
            return True
        return False

    def get_external_pad_info_from_concat(self):
        """Derive paddings from a concat (and an optional following concat) used as padding.

        Returns:
            tuple ``(paddings, consumed_vertices)``: ``paddings`` is a 6-element list in
            which entries 0-1 are set for a concat over axis 1 and entries 2-3 for a
            concat over axis 2; ``consumed_vertices`` lists this vertex and, when used,
            the following concat.
        """
        paddings = [0, 0, 0, 0, 0, 0]

        def record_padding(concat_vertex):
            # The first input of the concat is taken as the pad tensor; its extent along
            # the concat axis is written to both slots of that axis.
            const_inputs = self.graph.values_by_vertex_name[concat_vertex.name]
            pad_tensor = const_inputs[concat_vertex.input[0]]
            axis = concat_vertex.get_concat_info()["axis"]
            pad_size = pad_tensor.shape[axis]
            if axis == 1:
                paddings[0:2] = [pad_size, pad_size]
            elif axis == 2:
                paddings[2:4] = [pad_size, pad_size]

        consumed_vertices = [self]
        record_padding(self)

        following_concat = look_for_node(self._graph, self, [FwdChainNode(op="CONCATENATION")])
        if following_concat and following_concat.is_external_pad_concat():
            consumed_vertices.append(following_concat)
            record_padding(following_concat)

        return paddings, consumed_vertices

    def is_square_mul(self):
        """Tell whether this vertex has exactly two inputs and both are the same tensor."""
        if len(self.input) != 2:
            return False

        first, second = self.input[0], self.input[1]
        return first == second

    def verify_pow_is_square(self):
        """Validate that a POW vertex raises its input to the power of 2.

        Vertices whose op is not ``POW`` are ignored.

        Raises:
            UnsupportedSquareLayerError: If no exponent value is stored for this vertex, the
                stored exponent is empty, or any of its elements differs from 2.
        """
        if self.op != "POW":
            return

        input_tensors = self.graph.values_by_vertex_name[self.name]
        pow_value = input_tensors.get(self.input[POW_INPUT_ORDER.index("Y")], [])
        # The fallback for a missing exponent is a plain list, which has no ``tolist``. Convert
        # only array-like values, so a missing exponent is reported as an unsupported layer
        # instead of crashing with an AttributeError.
        if hasattr(pow_value, "tolist"):
            pow_value = pow_value.tolist()
        if not isinstance(pow_value, list):
            pow_value = [pow_value]

        if len(pow_value) < 1 or not np.all([x == 2.0 for x in pow_value]):
            raise UnsupportedSquareLayerError(
                f"Pow operator {self.name} can only be supported as square (got power of {pow_value})",
            )

    def get_feature_split_info(self):
        """Collect the sizes, offsets and output shapes of the used outputs of a split vertex.

        Returns:
            A ``(features_split_dims, split_indices, output_shapes)`` tuple with one entry per
            split output that is consumed by a successor: the size of that split along the
            split axis, the offset of that split along the split axis, and the input shape up
            to the split axis followed by the split size.

        Raises:
            UnsupportedFeatureSplitterLayerError: If the op is neither SPLIT nor SPLIT_V, or if
                the split axis is rejected by ``is_features_axis``.
            UnsupportedFeatureSplitterError: If the split sizes do not sum up to the input size
                along the split axis, or if more than one, but not all, of the outputs are used.
        """
        input_tensors = self.graph.values_by_vertex_name[self.name]
        if self.op == "SPLIT":
            order = SPLIT_INPUT_ORDER
        elif self.op == "SPLIT_V":
            order = SPLIT_V_INPUT_ORDER
        else:
            raise UnsupportedFeatureSplitterLayerError(f"Unsupported split layer at {self.name}, op={self.op}")

        # NOTE(review): axis is None when no value is stored for the axis input; it is used as
        # an index below before is_features_axis() gets the chance to reject it.
        axis = input_tensors.get(self.input[order.index("AXIS")], None)
        number_of_split = len(self.output)
        # The size of every output along the split axis.
        split_sizes = [shape[axis] for shape in self.output_shapes]

        if np.sum(split_sizes) != self.input_shapes[order.index("X")][axis]:
            raise UnsupportedFeatureSplitterError(
                f"Feature split node {self.name} must have an output shape where the features "
                "dimension is the sum of all splits sizes. "
                f"Found num_splits={number_of_split} and split_sizes={split_sizes}.",
            )

        # Successor entries ordered by the name of the first vertex in each entry.
        io_indices = sorted(self.get_vertex_successors_io_indices().items(), key=lambda item: item[1][0].name)

        # creating one list of all split used outputs
        chosen_outputs = [
            successor_input
            for successor_inputs in [item[1][0].input for item in io_indices if item[1][0].input]
            for successor_input in successor_inputs
            if successor_input in self.output
        ]
        # From here on io_indices holds the positions (in self.output) of the used outputs.
        io_indices = [i for i, output in enumerate(self.output) if output in chosen_outputs]

        # TODO: Add support of using more than one output (and not all) of split layer - SDK-30108
        # More than 1 split is used, but not all splits
        if len(io_indices) > 1 and len(io_indices) != number_of_split:
            raise UnsupportedFeatureSplitterError(
                f"Feature split node {self.name} must have an output for each split, "
                f"or no more than 1 output. Found number_of_split={number_of_split} "
                f"and num_outputs={len(io_indices)}.",
            )

        # The offset of each used split is the total size of the splits that precede it.
        split_indices = [sum(split_sizes[:io_index]) for io_index in io_indices]

        features_split_dims = []
        output_shapes = []
        for io_index in io_indices:
            features_split_dims.append(split_sizes[io_index])
            # Input dims before the split axis, followed by the size of this split.
            output_shapes.append(self.input_shapes[order.index("X")][:axis] + [split_sizes[io_index]])

        rank = len(self.input_shapes[order.index("X")])
        if not self.is_features_axis(axis, rank):
            raise UnsupportedFeatureSplitterLayerError(
                f"Feature splitter vertex {self.name} is splitting input over unsupported axis {axis}",
            )

        return features_split_dims, split_indices, output_shapes

    def get_resize_info(self):
        """Read the resize method, the target sizes and the pixels mode of a resize vertex.

        Stores ``align_corners`` and ``half_pixel_centers`` in ``self._attrs_dict``.

        Returns:
            A ``(resize_method, resize_sizes, pixels_mode)`` tuple. ``resize_sizes`` is the
            stored sizes value converted with ``tolist()``, or ``None`` when nothing is stored.
            ``pixels_mode`` prefers align-corners over half-pixels and is disabled otherwise.
        """
        const_inputs = self.graph.values_by_vertex_name[self.name]
        sizes_name = self.input[RESIZE_INPUT_ORDER.index("SIZES")]
        resize_sizes = const_inputs.get(sizes_name, None)
        resize_sizes = None if resize_sizes is None else resize_sizes.tolist()

        raw_options = self._info.BuiltinOptions()

        # Every op other than RESIZE_BILINEAR is parsed as nearest neighbor.
        is_bilinear = self.op == "RESIZE_BILINEAR"
        resize_method = ResizeMethod.bilinear if is_bilinear else ResizeMethod.nearest_neighbor
        resize_options = ResizeBilinearOptions() if is_bilinear else ResizeNearestNeighborOptions()
        resize_options.Init(raw_options.Bytes, raw_options.Pos)

        align_corners = resize_options.AlignCorners()
        self._attrs_dict["align_corners"] = align_corners
        half_pixel_centers = resize_options.HalfPixelCenters()
        self._attrs_dict["half_pixel_centers"] = half_pixel_centers

        if align_corners:
            pixels_mode = ResizeBilinearPixelsMode.align_corners
        elif half_pixel_centers:
            pixels_mode = ResizeBilinearPixelsMode.half_pixels
        else:
            pixels_mode = ResizeBilinearPixelsMode.disabled

        return resize_method, resize_sizes, pixels_mode

    def is_valid_reduce_max_min(self):
        """Tell whether the reduce keeps its dims and its axis is accepted by ``is_features_axis``."""
        keep_dims, reduce_axis = self.get_reduce_info()
        if not keep_dims:
            return keep_dims

        return self.is_features_axis(reduce_axis)

    def get_reduce_sum_info(self):
        """Return whether the reduce-sum is valid, together with its normalized axes.

        Negative axes are shifted by 4. The reduce is reported as valid when it keeps its dims,
        an axes value is stored, and axis 0 is not among the axes.

        Returns:
            An ``(is_valid, axes)`` tuple; ``axes`` is a list of ints, or ``None`` when no axes
            value is stored.
        """
        keep_dims, raw_axes = self.get_reduce_info()

        axes = None
        if raw_axes is not None:
            converted = raw_axes.tolist()
            # A zero-dimensional value converts to a scalar; wrap it for uniform handling.
            as_list = converted if isinstance(converted, list) else [converted]
            axes = [axis if axis >= 0 else axis + 4 for axis in as_list]

        is_valid = keep_dims and axes is not None and 0 not in axes
        return is_valid, axes

    def get_reduce_l2_info(self):
        """Collect the reduce info of the SUM -> SQRT chain that follows this vertex.

        Returns:
            An ``(is_valid, axes, consumed_vertices)`` tuple, where the first two items come
            from the SUM vertex's ``get_reduce_sum_info`` and ``consumed_vertices`` holds the
            SUM vertex and the SQRT vertex looked up after it.
        """
        reduce_sum = look_for_node(self.graph, self, [FwdChainNode(op="SUM")])
        is_valid, axes = reduce_sum.get_reduce_sum_info()
        sqrt = look_for_node(self.graph, reduce_sum, [FwdChainNode(op="SQRT")])
        consumed_vertices = [reduce_sum, sqrt]

        return is_valid, axes, consumed_vertices

    def is_null_transpose(self):
        """Tell whether this TRANSPOSE has at most one non-batch output dim that is not 1.

        The first output dim is skipped when counting the dims that are equal to 1.
        """
        if self.op != "TRANSPOSE":
            return False

        first_output_shape = self.output_shapes[0]
        unit_dims = sum(1 for dim in first_output_shape[1:] if dim == 1)
        return unit_dims >= len(first_output_shape) - 2

    def get_null_vertices(self):
        """Return the vertex that directly follows this one when it matches the expected op.

        A TRANSPOSE is matched with a following RESHAPE and an UNPACK with a following
        EXPAND_DIMS. Any other op, or a missing follower, yields an empty list.
        """
        if self.op == "TRANSPOSE":
            follower_op = "RESHAPE"
        elif self.op == "UNPACK":
            follower_op = "EXPAND_DIMS"
        else:
            return []

        follower = look_for_node(self._graph, self, [FwdChainNode(op=follower_op)])
        return [] if follower is None else [follower]

    def is_global_max_pool(self):
        """Tell whether this vertex is a REDUCE_MAX whose axis is accepted by ``is_spatial_axis``."""
        if self.op == "REDUCE_MAX":
            _, reduce_axis = self.get_reduce_info()
            return self.is_spatial_axis(reduce_axis)

        return False

    def get_reduce_info(self):
        """Read the keep-dims flag and the stored axis value of a reduce vertex.

        Stores ``keep_dims`` in ``self._attrs_dict``.

        Returns:
            A ``(keep_dims, axis)`` tuple; ``axis`` is ``None`` when no value is stored for the
            axis input.
        """
        const_inputs = self.graph.values_by_vertex_name[self.name]
        axis_name = self.input[REDUCE_INPUT_ORDER.index("AXIS")]
        axis = const_inputs.get(axis_name, None)

        raw_options = self._info.BuiltinOptions()
        reducer_options = ReducerOptions()
        reducer_options.Init(raw_options.Bytes, raw_options.Pos)

        keep_dims = reducer_options.KeepDims()
        self._attrs_dict["keep_dims"] = keep_dims

        return keep_dims, axis

    @staticmethod
    def is_features_axis(axis, rank=None):
        if axis is None:
            return False

        axis = axis.tolist()
        if isinstance(axis, list):
            if len(axis) > 1:
                return False

            axis = axis[0]

        if axis in [-1, 3] or (rank == 2 and axis == 1):
            return True

        return False

    @staticmethod
    def is_spatial_axis(axis):
        if axis is None:
            return False

        axis = axis.tolist()
        if not isinstance(axis, list):
            return False

        return axis == [1, 2]

    def is_shuffle(self):
        """Tell whether this RESHAPE is followed by a TRANSPOSE -> RESHAPE chain."""
        if self.op != "RESHAPE":
            return False

        chain = [FwdChainNode(op="TRANSPOSE"), FwdChainNode(op="RESHAPE")]
        closing_reshape = look_for_node(self._graph, self, chain)
        return closing_reshape is not None

    def get_transpose_perm(self):
        """Return the stored permutation input of this vertex, converted with ``tolist()``."""
        const_inputs = self.graph.values_by_vertex_name[self.name]
        perm_name = self.input[TRANSPOSE_INPUT_ORDER.index("PERM")]
        return const_inputs[perm_name].tolist()

    def is_width_features_transpose(self):
        """Tell whether this vertex is a TRANSPOSE with the permutation ``[0, 1, 3, 2]``."""
        return self.op == "TRANSPOSE" and self.get_transpose_perm() == [0, 1, 3, 2]

    def is_height_width_transpose(self):
        """Tell whether this vertex is a TRANSPOSE with the permutation ``[0, 2, 1, 3]``."""
        return self.op == "TRANSPOSE" and self.get_transpose_perm() == [0, 2, 1, 3]

    def is_supported_transpose(self):
        """Tell whether this vertex is a height-width or a width-features transpose."""
        height_width = self.is_height_width_transpose()
        return height_width if height_width else self.is_width_features_transpose()

    def get_shuffle_reshape_transpose_info(self):
        """Collect the TRANSPOSE -> RESHAPE chain that follows this vertex.

        Returns:
            A ``(second_reshape, perm, consumed_vertices)`` tuple: the closing RESHAPE vertex,
            the permutation of the TRANSPOSE, and a list with the TRANSPOSE and the RESHAPE.
        """
        transpose_vertex = look_and_validate(self._graph, self, [FwdChainNode(op="TRANSPOSE")])
        perm = transpose_vertex.get_transpose_perm()

        closing_reshape = look_and_validate(self._graph, transpose_vertex, [FwdChainNode(op="RESHAPE")])
        consumed_vertices = [transpose_vertex, closing_reshape]

        return closing_reshape, perm, consumed_vertices

    def get_reshape_shape_relations(self):
        """Relate the input shape to the output shape; used to identify depth to space with reshape.

        The output shape is taken from the RESHAPE that closes a following TRANSPOSE -> RESHAPE
        chain when such a chain exists, and from this vertex otherwise.

        Returns:
            ``(factor, is_same_batch, is_width_factored, is_height_factored)`` where ``factor``
            is the integer ratio between the input and output last (index 3) dims, or
            ``(-1, False, False, False)`` when either shape has fewer than 4 dims.
        """
        input_shape = self.input_shapes[0]

        chain = [FwdChainNode(op="TRANSPOSE"), FwdChainNode(op="RESHAPE")]
        closing_reshape = look_for_node(self._graph, self, chain)
        # reshape-transpose-reshape style when the chain exists, single reshape style otherwise
        output_shape = closing_reshape.output_shapes[0] if closing_reshape else self.output_shapes[0]

        if min(len(input_shape), len(output_shape)) < 4:
            return -1, False, False, False

        factor = input_shape[3] // output_shape[3]

        return (
            factor,
            input_shape[0] == output_shape[0],
            input_shape[2] * factor == output_shape[2],
            input_shape[1] * factor == output_shape[1],
        )

    def is_single_channel_reshape_depth_to_space(self):
        """Check whether this RESHAPE scales the width or the height by the features ratio.

        Returns:
            ``(False, None)`` for a non-RESHAPE vertex. Otherwise ``(condition, block_size)``,
            where ``block_size`` is ``[1, factor]`` when the width is factored and
            ``[factor, 1]`` in every other case.
        """
        if self.op != "RESHAPE":
            return False, None

        factor, same_batch, width_factored, height_factored = self.get_reshape_shape_relations()
        # The factor must be greater than 1 to avoid matching an identity combination.
        is_single_channel_d2s = same_batch and (width_factored or height_factored) and factor > 1

        if width_factored:
            block_size = [1, factor]
        else:
            block_size = [factor, 1]

        return is_single_channel_d2s, block_size

    def get_depth_to_space_info(self):
        """Return the vertices consumed by a depth to space and its output shapes.

        Returns:
            For a RESHAPE followed by a TRANSPOSE -> RESHAPE chain: the chain vertices and the
            output shapes of the last of them. Otherwise an empty list and this vertex's own
            output shapes.
        """
        if self.op != "RESHAPE":
            return [], self.output_shapes

        chain = [FwdChainNode(op="TRANSPOSE"), FwdChainNode(op="RESHAPE")]
        chain_vertices = get_all_nodes_in_chain(self._graph, self, chain)
        if not chain_vertices:
            # single reshape depth to space
            return [], self.output_shapes

        # reshape-transpose-reshape depth to space
        return chain_vertices, chain_vertices[-1].output_shapes

    def get_d2s_block_size(self):
        """Return the block size of a depth to space vertex.

        For a RESHAPE vertex the single-channel check is tried first and the reshape-block check
        second; ``None`` is returned when neither matches. For any other op the block size is
        read from the depth-to-space builtin options and stored in ``self._attrs_dict``.
        """
        if self.op != "RESHAPE":
            # atomic depth to space
            raw_options = self._info.BuiltinOptions()
            d2s_options = DepthToSpaceOptions()
            d2s_options.Init(raw_options.Bytes, raw_options.Pos)
            block_size = d2s_options.BlockSize()
            self._attrs_dict["block_size"] = block_size
            return self._attrs_dict["block_size"]

        # reshape or reshape-transpose-reshape depth to space
        is_single_channel, single_block_size = self.is_single_channel_reshape_depth_to_space()
        if is_single_channel:
            return single_block_size

        is_reshape_block, reshape_block_size = self.is_depth_to_space_reshape_block()
        return reshape_block_size if is_reshape_block else None

    def get_l2_normalization_activation(self):
        """Read the fused activation of an L2 normalization vertex.

        The activation is stored under ``fused_activation`` in ``self._attrs_dict`` and returned.
        """
        raw_options = self._info.BuiltinOptions()
        l2_options = L2NormOptions()
        l2_options.Init(raw_options.Bytes, raw_options.Pos)

        fused_activation = self.get_fused_activation_op(l2_options.FusedActivationFunction())
        self._attrs_dict["fused_activation"] = fused_activation
        return self._attrs_dict["fused_activation"]

    def get_ew_op_fused_activation(self):
        """Read the fused activation of an elementwise vertex.

        Returns:
            The result of ``get_fused_activation_op``; for add/sub/mul/div ops it is also stored
            under ``fused_activation`` in ``self._attrs_dict``.

        Raises:
            UnsupportedEWLayerError: If the op belongs to none of the handled op groups.
        """
        raw_options = self._info.BuiltinOptions()
        if self.op in ["NEG", "ADD_N", "MAXIMUM", "MINIMUM"]:
            # NEG, ADD_N, MAXIMUM and MINIMUM don't have a fused activation
            return self.get_fused_activation_op(None)

        options_by_ops = (
            (ADD_OPS, AddOptions),
            (SUB_OPS, SubOptions),
            (MUL_OPS, MulOptions),
            (DIV_OPS, DivOptions),
        )
        ew_options = None
        for ops_group, options_cls in options_by_ops:
            if self.op in ops_group:
                ew_options = options_cls()
                break
        else:
            raise UnsupportedEWLayerError(f"Unknown EW op type {self.op} for vertex {self.name}")

        ew_options.Init(raw_options.Bytes, raw_options.Pos)
        fused_activation = self.get_fused_activation_op(ew_options.FusedActivationFunction())
        self._attrs_dict["fused_activation"] = fused_activation
        return self._attrs_dict["fused_activation"]

    def is_threshold_activation(self):
        """Tell whether this vertex is an end of a GREATER -> CAST -> MUL chain.

        A GREATER vertex is checked forward for CAST -> MUL, and a MUL vertex is checked backward
        for CAST -> GREATER. Every other op yields ``False``.
        """
        if self.op == "GREATER":
            chain = [FwdChainNode(op="CAST"), FwdChainNode(op="MUL")]
        elif self.op == "MUL":
            chain = [BwdChainNode(op="CAST"), BwdChainNode(op="GREATER")]
        else:
            return False

        return look_for_node(self._graph, self, chain) is not None

    def get_threshold_activation_values(self):
        """Return the threshold of the GREATER vertex that precedes this one through a CAST.

        Returns:
            A ``(thresh_value, consumed_vertices)`` tuple: the value stored for the second input
            of the GREATER vertex (``None`` when nothing is stored) and the CAST and GREATER
            vertices.
        """
        chain = [BwdChainNode(op="CAST"), BwdChainNode(op="GREATER")]
        cast_vertex, greater_vertex = get_all_nodes_in_chain(self._graph, self, chain)

        greater_consts = self.graph.values_by_vertex_name[greater_vertex.name]
        threshold = greater_consts.get(greater_vertex.input[1], None)
        return threshold, [cast_vertex, greater_vertex]

    def get_activation_less_or_greater_values(self):
        """Resolve the activation type and the values of a LESS / GREATER vertex.

        The first value stored for this vertex is used. When that value is the first input of
        the vertex the comparison direction is flipped, so the returned activation type is
        expressed relative to the other input.

        Returns:
            An ``(activation_type, values, consumed_vertices)`` tuple; ``consumed_vertices``
            holds the following CAST vertex when one exists.

        Raises:
            UnsupportedActivationLayerError: If no value is stored for this vertex.
        """
        cast_node = look_for_node(self.graph, self, [FwdChainNode(op="CAST")])
        consumed_vertices = [cast_node] if cast_node else []

        const_inputs = self.graph.values_by_vertex_name[self.name]
        if len(const_inputs) == 0:
            raise UnsupportedActivationLayerError(
                f"Could not find input tensors for node {self.name}, elementwise greater/less is not supported."
            )

        const_name, raw_values = next(iter(const_inputs.items()))
        # One-dimensional values are used as they are; otherwise the leading axis is squeezed.
        if len(raw_values.shape) == 1:
            values = raw_values
        else:
            values = np.squeeze(raw_values, axis=0)
        const_idx = self.input.index(const_name)

        activation_by_position_and_op = {
            (0, "LESS"): ActivationType.greater,
            (0, "GREATER"): ActivationType.less,
            (1, "LESS"): ActivationType.less,
            (1, "GREATER"): ActivationType.greater,
        }
        activation_type = activation_by_position_and_op[(const_idx, self.op)]

        return activation_type, values, consumed_vertices

    def is_softplus_activation(self):
        """Check whether this EXP vertex starts an EXP -> ADD -> LOG chain that adds the scalar 1.

        Returns:
            A bare ``False`` when this vertex is not an EXP op (kept as is for existing
            callers). Otherwise an ``(is_softplus, consumed_vertices)`` tuple, where
            ``consumed_vertices`` is the ADD -> LOG chain on a match and an empty list otherwise.
        """
        if self.op != "EXP":
            return False

        consumed_vertices = get_all_nodes_in_chain(self.graph, self, [FwdChainNode(op="ADD"), FwdChainNode(op="LOG")])
        if consumed_vertices is None:
            return False, []

        add = consumed_vertices[0]
        add_input_tensors = self.graph.values_by_vertex_name[add.name]
        add_value = add_input_tensors.get(add.input[0], add_input_tensors.get(add.input[1], None))
        # Compare with None explicitly: truth-testing a multi-element array raises ValueError,
        # whereas such a constant should simply be reported as "not a softplus".
        if add_value is None or np.asarray(add_value).tolist() != 1:
            return False, []

        return True, consumed_vertices

    def is_mish_activation(self):
        """Tell whether this vertex is an end of an EXP -> ADD -> LOG -> TANH -> MUL chain.

        An EXP vertex must be a softplus start and is checked forward; a MUL vertex is checked
        backward and the EXP it reaches must be a softplus start. Other ops yield ``False``.
        """
        if self.op == "EXP":
            if not self.is_softplus_activation()[0]:
                return False

            forward_chain = [
                FwdChainNode(op="ADD"),
                FwdChainNode(op="LOG"),
                FwdChainNode(op="TANH"),
                FwdChainNode(op="MUL"),
            ]
            closing_mul = look_for_node(self._graph, self, forward_chain)
            return closing_mul is not None

        if self.op == "MUL":
            backward_chain = [
                BwdChainNode(op="TANH"),
                BwdChainNode(op="LOG"),
                BwdChainNode(op="ADD"),
                BwdChainNode(op="EXP"),
            ]
            opening_exp = look_for_node(self._graph, self, backward_chain)
            return opening_exp is not None and opening_exp.is_softplus_activation()[0]

        return False

    def get_mish_activation_vertices(self):
        """Return the vertices of the TANH <- LOG <- ADD <- EXP chain that precedes this vertex."""
        backward_chain = [
            BwdChainNode(op="TANH"),
            BwdChainNode(op="LOG"),
            BwdChainNode(op="ADD"),
            BwdChainNode(op="EXP"),
        ]
        return get_all_nodes_in_chain(self._graph, self, backward_chain)

    def has_common_stem(self, other):
        """Tell whether this vertex and ``other`` share at least one predecessor in the graph."""
        graph = self._graph
        for own_pred in graph.predecessors(self):
            if own_pred in graph.predecessors(other):
                return True

        return False

    def is_silu_activation(self):
        """Tell whether this vertex is part of a LOGISTIC -> MUL pair that shares a predecessor.

        A LOGISTIC vertex is checked against the following elementwise MUL, and a MUL vertex
        against the preceding LOGISTIC. Every other op yields ``False``.
        """
        if self.op == "LOGISTIC":
            following_mul = look_for_node(self._graph, self, [FwdChainNode("MUL")])
            if not following_mul or not following_mul.is_ew_mult():
                return False
            return self.has_common_stem(following_mul)

        if self.op == "MUL":
            preceding_sigmoid = look_for_node(self._graph, self, [BwdChainNode("LOGISTIC")])
            if not preceding_sigmoid:
                return False
            return self.has_common_stem(preceding_sigmoid) and self.is_ew_mult()

        return False

    def get_silu_activation_vertices(self):
        """Return the vertices of the single-step backward chain that reaches a LOGISTIC vertex."""
        backward_chain = [BwdChainNode("LOGISTIC")]
        return get_all_nodes_in_chain(self._graph, self, backward_chain)

    def is_swish_activation_first_mul(self):
        """Tell whether this MUL opens a MUL -> LOGISTIC -> MUL chain.

        The closing MUL must be an elementwise multiplication that shares a predecessor with
        this vertex.
        """
        if self.op != "MUL":
            return False

        chain = [FwdChainNode(op="LOGISTIC"), FwdChainNode(op="MUL")]
        closing_mul = look_for_node(self._graph, self, chain)
        return closing_mul is not None and closing_mul.is_ew_mult() and self.has_common_stem(closing_mul)

    def is_swish_activation_second_mul(self):
        """Tell whether this elementwise MUL closes a MUL -> LOGISTIC -> MUL chain.

        The opening MUL must share a predecessor with this vertex.
        """
        if not self.is_ew_mult():
            return False

        chain = [BwdChainNode(op="LOGISTIC"), BwdChainNode(op="MUL")]
        opening_mul = look_for_node(self._graph, self, chain)
        return opening_mul is not None and self.has_common_stem(opening_mul)

    def get_swish_beta(self):
        """Return the value stored for the MUL that opens the preceding LOGISTIC <- MUL chain.

        Returns:
            A ``(swish_beta, consumed_vertices)`` tuple. ``swish_beta`` is the value stored for
            the first input of that MUL, falling back to the second input and then to ``None``.

        Raises:
            CantFindSwishBetaError: If the backward chain is not found.
        """
        backward_chain = [BwdChainNode(op="LOGISTIC"), BwdChainNode(op="MUL")]
        consumed_vertices = get_all_nodes_in_chain(self._graph, self, backward_chain)
        if not consumed_vertices:
            raise CantFindSwishBetaError(self.name)

        beta_mul = consumed_vertices[-1]
        mul_consts = self.graph.values_by_vertex_name[beta_mul.name]
        first_input = beta_mul.input[0]
        second_input_value = mul_consts.get(beta_mul.input[1], None)
        swish_beta = mul_consts.get(first_input, second_input_value)

        return swish_beta, consumed_vertices

    def is_biased_delta_activation(self):
        """Tell whether this vertex is a sign op followed by ABS, or an ABS followed by a sign op.

        The sign ops are the ones listed in ``CUSTOM_SIGN_OPS``. Every other op yields ``False``.
        """
        if self.op in CUSTOM_SIGN_OPS:
            following_abs = look_for_node(self._graph, self, [FwdChainNode(op="ABS")])
            return following_abs is not None

        if self.op != "ABS":
            return False

        sign_chains = [[FwdChainNode(op=sign_op)] for sign_op in CUSTOM_SIGN_OPS]
        following_sign = get_node_from_possible_chains(self._graph, self, sign_chains)
        return following_sign is not None

    def get_biased_delta(self):
        """Return the delta of a biased-delta activation and the vertices it consumes.

        The vertex that follows this one (ABS after a sign op, or a sign op after ABS) is
        consumed. When that vertex has a single successor that is a non-elementwise MUL, the MUL
        is consumed as well and the value stored for it is the delta; otherwise the delta is 1.

        Returns:
            A ``(delta, consumed_vertices)`` tuple.

        Raises:
            CantFindBiasedDelatError: If the following vertex is not found.
        """
        if self.op in CUSTOM_SIGN_OPS:
            consumed_vertices = get_all_nodes_in_chain(self._graph, self, [FwdChainNode(op="ABS")])
        else:
            sign_chains = [[FwdChainNode(op=sign_op)] for sign_op in CUSTOM_SIGN_OPS]
            consumed_vertices = get_all_nodes_from_possible_chains(self._graph, self, sign_chains)

        if not consumed_vertices:
            raise CantFindBiasedDelatError(self.name)

        followers = list(self.graph.successors(consumed_vertices[-1]))
        if len(followers) != 1 or followers[0].op not in MUL_OPS or followers[0].is_ew_op():
            return 1, consumed_vertices

        delta_mul = followers[0]
        consumed_vertices.append(delta_mul)
        mul_consts = self.graph.values_by_vertex_name[delta_mul.name]
        first_input = delta_mul.input[0]
        second_input_value = mul_consts.get(delta_mul.input[1], None)
        delta = mul_consts.get(first_input, second_input_value)

        return delta, consumed_vertices

    def is_square(self):
        """Tell whether this vertex is a pow op, or a mul op that multiplies a tensor by itself."""
        if self.op in POW_OPS:
            return True

        return self.is_square_mul() if self.op in MUL_OPS else False

    def is_reduce_l2(self):
        """Tell whether this square vertex feeds, as its only successor, a SUM -> SQRT chain."""
        if not self.is_square():
            return False

        successors_count = len(list(self._graph.successors(self)))
        if successors_count > 1:
            return False

        chain = [FwdChainNode(op="SUM"), FwdChainNode(op="SQRT")]
        return look_for_node(self.graph, self, chain) is not None

    def get_slices_values(self, allow_stride=False):
        """Extract the ``[start, stop, step]`` triplets of a STRIDED_SLICE / SLICE vertex.

        For STRIDED_SLICE the begin / end / ellipsis masks are applied over up to 4 dims and the
        first dim is dropped from the result. For any other op the vertex is parsed as SLICE:
        the start comes from entries 1-2 of the begin tensor and the stop is start + size, where
        the size comes from the size input or, when the vertex has no such input, from the
        values stored for a preceding PACK vertex.

        Args:
            allow_stride: When ``False``, a STRIDED_SLICE whose step in the width or features
                axis is greater than 1 or negative is rejected.

        Returns:
            A ``(height_slices, width_slices, feature_slices)`` tuple of ``[start, stop, step]``
            lists. A stop of 0 is produced where the end mask is set, or the slice size is -1.

        Raises:
            UnsupportedSliceLayerError: For slices over more than 4 dims, new-axis / shrink-axis
                masks, unsupported strides, or when the slice size cannot be found.
        """
        input_tensors = self.graph.values_by_vertex_name[self.name]
        if self.op == "STRIDED_SLICE":
            start_val = input_tensors[self.input[STRIDED_SLICE_INPUT_ORDER.index("BEGIN")]]
            stop_val = input_tensors[self.input[STRIDED_SLICE_INPUT_ORDER.index("END")]]
            step_val = input_tensors[self.input[STRIDED_SLICE_INPUT_ORDER.index("STRIDES")]]

            if len(start_val) > 4:
                raise UnsupportedSliceLayerError(
                    f"Illegal slice in node {self.name} - rank of slice is larger than 4 dimensions.",
                )

            builtin_ops = self._info.BuiltinOptions()
            ops = StridedSliceOptions()
            ops.Init(builtin_ops.Bytes, builtin_ops.Pos)

            self._attrs_dict["new_axis_mask"] = ops.NewAxisMask()
            self._attrs_dict["shrink_axis_mask"] = ops.ShrinkAxisMask()
            if self._attrs_dict["new_axis_mask"] != 0 or self._attrs_dict["shrink_axis_mask"] != 0:
                raise UnsupportedSliceLayerError(
                    f"Found new axis or shrink axis in slice node {self.name}, which is not supported",
                )

            # The masks are bit fields; they are unpacked into per-dim lists, least significant
            # bit first, and zero-padded up to 4 entries.
            self._attrs_dict["begin_mask_val"] = ops.BeginMask()
            begin_mask = [int(d) for d in str(bin(self._attrs_dict["begin_mask_val"])[2:])][::-1]
            begin_mask = np.append(begin_mask, (4 - len(begin_mask)) * [0]).astype(np.int32)

            self._attrs_dict["end_mask_val"] = ops.EndMask()
            end_mask = [int(d) for d in str(bin(self._attrs_dict["end_mask_val"])[2:])][::-1]
            end_mask = np.append(end_mask, (4 - len(end_mask)) * [0]).astype(np.int32)

            # Position of the lowest set bit of the ellipsis mask, or -1 when no bit is set.
            self._attrs_dict["ellipsis_mask"] = ops.EllipsisMask()
            ellipsis_str = str(bin(self._attrs_dict["ellipsis_mask"]))[:1:-1]
            ellipsis = ellipsis_str.index("1") if "1" in ellipsis_str else -1

            start, stop, step = [0] * 4, [0] * 4, [1] * 4
            dim_index = 0
            for i, curr in enumerate(start_val):
                # If ellipsis is used in this dim, need to find how many more dims are specified to be sliced,
                # and skip those that are ignored due to the ellipsis
                if i == ellipsis:
                    dims_to_skip = 4 - len(start_val)
                    dim_index += dims_to_skip + 1
                    continue
                # If begin_mask[i] == 1 then the start value for dim i is 0
                if begin_mask[i] == 0:
                    start[dim_index] = curr
                # If end_mask[i] == 1 then the end value for dim i is 0 (later changed to output_shape[i])
                if end_mask[i] == 0:
                    stop[dim_index] = stop_val[i]
                step[dim_index] = step_val[i]
                dim_index += 1

            # Drop the first dim; the remaining entries are height, width and features.
            start = start[1:]
            stop = stop[1:]
            step = step[1:]
            if any(x > 1 or x < 0 for x in step[1:]) and not allow_stride:
                raise UnsupportedSliceLayerError(
                    f"Slices with stride > 1 or stride < 0 in width or features axis "
                    f"in node {self.name} are not supported.",
                )
            if step[0] < 0:
                raise UnsupportedSliceLayerError(
                    f"Slices with stride < 0 in height axis in node {self.name} are not supported.",
                )
        else:
            # Slice op
            start_val = input_tensors[self.input[SLICE_INPUT_ORDER.index("BEGIN")]]
            if len(self.input) > 2:
                size_key = self.input[SLICE_INPUT_ORDER.index("SIZE")]
                if size_key not in input_tensors:
                    raise UnsupportedSliceLayerError(f"Size input tensor not found in node {self.name}")
                size_val = input_tensors[size_key]
                crop_size = np.append(size_val[1:3], [0]).astype(np.int32)
            else:
                pack_node = look_for_node(self.graph, self, [BwdChainNode(op="PACK")])
                if not pack_node:
                    # Without a size tensor or a PACK predecessor there is nothing to derive the
                    # stop values from; report it instead of failing on an unbound local below.
                    raise UnsupportedSliceLayerError(f"Size input tensor not found in node {self.name}")
                # The stored values are kept in a mapping keyed by tensor name: the sizes are
                # the mapping's values, not its keys.
                pack_consts = self.graph.values_by_vertex_name[pack_node.name]
                pack_values = [x.tolist() for x in pack_consts.values()]
                crop_size = np.append(pack_values, [0]).astype(np.int32)

            start = np.append(start_val[1:3], [0]).astype(np.int32)
            # A size of -1 yields a stop of 0.
            stop = [sum(x) if x[1] != -1 else 0 for x in zip(start, crop_size)]
            step = [1] * 3

        height_slices = [start[0], stop[0], step[0]]
        width_slices = [start[1], stop[1], step[1]]
        feature_slices = [start[2], stop[2], step[2]]

        return height_slices, width_slices, feature_slices

    def is_space_to_depth_slice_block(self):
        """Tell whether this slice belongs to a slices + concat block that forms a space to depth.

        Only a block size of 2 is recognized. All the successors of this vertex's predecessor
        must be slices that feed the same concat, each with a step of 2 over height and width
        and a full features range. There must be 4 slices whose
        ``[height_slice_start, width_slice_start]`` pairs are taken from
        ``[[0, 0], [0, 1], [1, 0], [1, 1]]``.
        """
        if self.op not in SLICE_OPS:
            return False

        pred = next(iter(self._graph.predecessors(self)))
        input_shape = self.input_shapes[0]
        if len(input_shape) != 4:
            return False

        _, input_height, input_width, input_features = input_shape
        concat_node = None
        slice_start_pairs = []
        successors = list(self._graph.successors(pred))
        for succ in successors:
            if succ.op not in SLICE_OPS:
                return False

            next_successors = list(self._graph.successors(succ))
            if len(next_successors) != 1 or next_successors[0].op not in CONCAT_OPS:
                return False

            # Every slice has to feed the very same concat.
            if concat_node is None:
                concat_node = next_successors[0]
            elif next_successors[0] != concat_node:
                return False

            # Each item is a [start, stop, step] triplet.
            height_slice, width_slice, feature_slice = succ.get_slices_values(allow_stride=True)
            if (
                height_slice[1] not in [0, input_height]
                or height_slice[2] != 2
                or width_slice[1] not in [0, input_width]
                or width_slice[2] != 2
                or feature_slice[0] != 0
                or feature_slice[1] not in [0, input_features]
                or feature_slice[2] != 1
            ):
                return False

            # Collect the START of the height slice (index 0), not its stop (index 1): the stop
            # was already validated above, and it equals the input height for explicit ends,
            # which would never match the expected offsets.
            slice_start_pairs.append([height_slice[0], width_slice[0]])

        return len(slice_start_pairs) == 4 and all(
            pair in [[0, 0], [0, 1], [1, 0], [1, 1]] for pair in slice_start_pairs
        )

    def is_depth_to_space_reshape_block(self):
        """Check whether this RESHAPE implements a depth to space, and find its block sizes.

        Three variants are recognized: a single channel reshape (see
        ``is_single_channel_reshape_depth_to_space``), a reshape -> transpose -> reshape block
        with a 6-dim intermediate shape, and a lone reshape of an input with 1x1 spatial dims.

        Returns:
            ``(False, [])`` for a non-RESHAPE vertex or shapes with fewer than 4 dims. Otherwise
            an ``(is_depth_to_space, block_sizes)`` tuple; ``block_sizes`` is returned even when
            the check fails.
        """
        if self.op != "RESHAPE":
            return False, []

        # case #1 - edge case for variant with single channel asymmetric block sized d2s
        is_single_channel_reshape, block_sizes = self.is_single_channel_reshape_depth_to_space()
        if is_single_channel_reshape:
            return is_single_channel_reshape, block_sizes

        input_shape = self.input_shapes[0]
        output_shape = self.output_shapes[0]
        if len(input_shape) < 4 or len(output_shape) < 4:
            return False, []

        # case #2 - classic combination of reshape+transpose+reshape that defines a default depth_to_space of
        # type DCR
        possible_chains = [[FwdChainNode(op="TRANSPOSE"), FwdChainNode(op="RESHAPE")]]
        nodes = get_all_nodes_from_possible_chains(self._graph, self, possible_chains)
        if nodes is not None:
            transpose = nodes[0]
            second_reshape = nodes[1]
            perm = transpose.get_transpose_perm()
            block_output_shape = second_reshape.output_shapes[0]
            block_sizes = self.output_shapes[0][3:5]

            # The conditions are evaluated left to right and short-circuit, so the rank checks
            # guard the indexing that follows: block_sizes has two entries only for a 6-dim
            # intermediate shape, and block_output_shape[3] exists only for a 4-dim output.
            cond2s_reshape_block_cond = (
                len(output_shape) == 6
                and len(block_output_shape) == 4
                and perm == [0, 1, 3, 2, 4, 5]  # TF implementation is equivalent to DCR
                and block_sizes[0] == block_sizes[1]
                and block_output_shape[1] == input_shape[1] * block_sizes[0]
                and block_output_shape[2] == input_shape[2] * block_sizes[0]
                and block_output_shape[3] == int(input_shape[3] // (block_sizes[0] * block_sizes[0]))
            )
        else:
            # case #3 - numpy reshape only arrangement of depth_to_space, with input of 1x1 spatial dimension
            spatial_dim_cond = input_shape[1] == 1 and input_shape[2] == 1
            output_shape_cond = len(output_shape) == 4
            block_sizes = [int(output_shape[1] // input_shape[1]), int(output_shape[2] // input_shape[2])]
            ratio_cond = (
                input_shape[1] * block_sizes[0] == output_shape[1]
                and input_shape[2] * block_sizes[1] == output_shape[2]
                and int(input_shape[3] // (block_sizes[0] * block_sizes[1])) == output_shape[3]
            )
            cond2s_reshape_block_cond = spatial_dim_cond and ratio_cond and output_shape_cond
        return cond2s_reshape_block_cond, block_sizes

    def is_relu6_clip(self):
        """Tell whether this MINIMUM has a stored constant of all 6 and is followed by a RELU."""
        if self.op != "MINIMUM":
            return False

        const_inputs = self.graph.values_by_vertex_name[self.name]
        const_name = self.input[MINIMUM_INPUT_ORDER.index("CONST")]
        upper_bound = const_inputs.get(const_name, None)
        following_relu = look_for_node(self._graph, self, [FwdChainNode(op="RELU")])
        return following_relu is not None and upper_bound is not None and np.all(upper_bound == 6)

    def is_hardsigmoid(self):
        """Tell whether this MINIMUM has a stored constant of all 1 and is followed by a RELU."""
        if self.op != "MINIMUM":
            return False

        const_inputs = self.graph.values_by_vertex_name[self.name]
        const_name = self.input[MINIMUM_INPUT_ORDER.index("CONST")]
        upper_bound = const_inputs.get(const_name, None)
        following_relu = look_for_node(self._graph, self, [FwdChainNode(op="RELU")])
        return following_relu is not None and upper_bound is not None and np.all(upper_bound == 1)

    def get_hardsigmoid_info(self):
        """Return the alpha and beta of the hard sigmoid and the RELU vertex it consumes.

        Returns:
            An ``(alpha, beta, consumed_vertices)`` tuple; ``consumed_vertices`` holds the
            result of looking up the RELU that follows this vertex.
        """
        # alpha and beta always 1.0 and 0.0 respectively in tflite
        alpha, beta = 1.0, 0.0
        following_relu = look_for_node(self._graph, self, [FwdChainNode(op="RELU")])
        return alpha, beta, [following_relu]

    def get_min_max_info(self):
        """Return the clipping range defined by this MINIMUM / MAXIMUM vertex.

        A MINIMUM provides the upper bound and a following MAXIMUM (consumed when found) the
        lower bound; a MAXIMUM provides the lower bound and a following MINIMUM (consumed when
        found) the upper bound. A bound without a matching vertex is infinite.

        Returns:
            A ``(min_value, max_value, consumed_vertex)`` tuple.
        """
        own_consts = self.graph.values_by_vertex_name[self.name]
        consumed_vertex = []
        if self.op == "MINIMUM":
            max_value = own_consts.get(self.input[MINIMUM_INPUT_ORDER.index("CONST")], None)

            following_max = look_for_node(self.graph, self, [FwdChainNode(op="MAXIMUM")])
            if not following_max:
                # structure of (min) perform by clipping of [-np.inf, max_value]
                min_value = -np.inf
            else:
                # structure of (min) -> (max) perform by clipping of [min_value, max_value]
                following_consts = self.graph.values_by_vertex_name[following_max.name]
                min_value = following_consts.get(following_max.input[MAXIMUM_INPUT_ORDER.index("CONST")], None)
                consumed_vertex.append(following_max)

        elif self.op == "MAXIMUM":
            min_value = own_consts.get(self.input[MAXIMUM_INPUT_ORDER.index("CONST")], None)

            following_min = look_for_node(self.graph, self, [FwdChainNode(op="MINIMUM")])
            if not following_min:
                # structure of (max) perform by clipping of [min_value, np.inf]
                max_value = np.inf
            else:
                # structure of (max) -> (min) perform by clipping of [min_value, max_value]
                following_consts = self.graph.values_by_vertex_name[following_min.name]
                max_value = following_consts.get(following_min.input[MINIMUM_INPUT_ORDER.index("CONST")], None)
                consumed_vertex.append(following_min)

        return min_value, max_value, consumed_vertex

    def get_space_to_depth_block_size(self):
        """Read the block size from the operator's SpaceToDepth options.

        The value is stored under ``"block_size"`` in ``self._attrs_dict`` and returned.
        """
        raw_options = self._info.BuiltinOptions()
        space_to_depth_options = SpaceToDepthOptions()
        space_to_depth_options.Init(raw_options.Bytes, raw_options.Pos)

        self._attrs_dict["block_size"] = space_to_depth_options.BlockSize()
        return self._attrs_dict["block_size"]

    def get_space_to_depth_consumed_vertices(self):
        """Collect the vertices making up a slice-based space-to-depth pattern.

        For a slice vertex, returns every successor of this vertex's first predecessor
        followed by the first successor of the first of those successors. Returns an
        empty list when this vertex's op is not in ``SLICE_OPS``.
        """
        if self.op not in SLICE_OPS:
            return []

        parent = next(iter(self.graph.predecessors(self)))
        sibling_slices = list(self.graph.successors(parent))
        # presumably the concat that merges the slices back together — named so by the original code
        merging_concat = next(iter(self.graph.successors(sibling_slices[0])))

        return [*sibling_slices, merging_concat]

    def is_features_reshape(self):
        """Tell whether this reshape splits the last input axis into the last two output axes.

        Both the first input shape and the first output shape must have rank 4, the first
        two dimensions must match, the third input dimension must be 1, and the fourth
        input dimension must equal the product of the last two output dimensions.
        """
        out_dims = self.output_shapes[0]
        in_dims = self.input_shapes[0]

        in_rank, out_rank = len(in_dims), len(out_dims)
        if (in_rank, out_rank) != (4, 4):
            return False

        keeps_leading_dims = in_dims[0:2] == out_dims[0:2]
        return keeps_leading_dims and in_dims[2] == 1 and in_dims[3] == out_dims[2] * out_dims[3]

    def is_flat_to_frames_reshape(self):
        """Tell whether this reshape unfolds a rank-2 input into a rank-3 or rank-4 output.

        The first input shape must have rank 2. For a rank-3 output the product of its
        last two dimensions must equal the second input dimension. For a rank-4 output the
        same product rule applies and the second output dimension must be 1.
        """
        out_dims = self.output_shapes[0]
        in_dims = self.input_shapes[0]

        if len(in_dims) != 2:
            return False

        out_rank = len(out_dims)
        if out_rank == 3:
            return out_dims[1] * out_dims[2] == in_dims[1]
        if out_rank == 4:
            return (out_dims[2] * out_dims[3] == in_dims[1]) and out_dims[1] == 1
        return False

    def is_shape_op(self):
        """Tell whether this vertex's op is ``SHAPE``."""
        op_name = self.op
        return op_name == "SHAPE"

    def get_start_node_preds(self):
        """Return this start vertex's predecessors, excluding dequantize vertices.

        More than one predecessor is accepted only when this is an add op with exactly
        two predecessors, a concat op, or when all predecessors but one are dequantize ops.

        Raises:
            CantFindGraphStartError: if the vertex has several predecessors and none of
                the accepted configurations applies.
        """
        preds = list(self._graph.predecessors(self))
        dequantize_preds = [pred for pred in preds if pred.op in DEQUANTIZE_OPS]

        if len(preds) > 1:
            is_supported_multi_input = (
                (self.op in ADD_OPS and len(preds) == 2)
                or (self.op in CONCAT_OPS and len(preds) >= 2)
                or (len(dequantize_preds) + 1 == len(preds))
            )
            if not is_supported_multi_input:
                raise CantFindGraphStartError(
                    f"Start node {self.name} of type {self.op} has illegal number of input "
                    f"nodes({len(preds)}), which is not supported.",
                )

        return [pred for pred in preds if pred not in dequantize_preds]

    def get_input_shapes(self):
        """Return the vertex's ``input_shapes`` attribute unchanged."""
        shapes = self.input_shapes
        return shapes

    def get_output_shapes(self, **kwargs):
        """Return the vertex's ``output_shapes`` attribute unchanged.

        Any keyword arguments are accepted and ignored.
        """
        return self.output_shapes

    def get_input_layer_shapes(self):
        """Return the shapes to use when this vertex is an input layer.

        These are the vertex's ``output_shapes``.
        """
        shapes = self.output_shapes
        return shapes

    def is_pre_layer_op(self):
        """Tell whether this vertex's op is listed in ``PRE_LAYER_OPS``."""
        op_name = self.op
        return op_name in PRE_LAYER_OPS

    def get_dynamic_kernel_shape(self):
        """Return the second input's shape with its first dimension moved to the end.

        Only applies when the vertex has exactly three input shapes; otherwise ``None``
        is returned.
        """
        if len(self.input_shapes) != 3:
            return None

        # presumably the second input is the kernel tensor — confirm against callers
        kernel_dims = self.input_shapes[1]
        return list(kernel_dims[1:]) + [kernel_dims[0]]

    @staticmethod
    def is_const():
        """Report whether this vertex is a constant; always ``False`` for this class."""
        return False

    def is_inv_pos_activation(self):
        """Tell whether this vertex is a division whose first operand is a constant one.

        The vertex qualifies when its op is in ``DIV_OPS``, exactly one tensor value is
        stored for it, that value's name is contained in the first input name, and the
        value stored under the first input name equals 1 in every element.

        Returns:
            ``True`` for the ``Div(1, x)`` form described above, ``False`` otherwise.
        """
        if self.op not in DIV_OPS:
            return False

        input_tensors = self.graph.values_by_vertex_name[self.name]
        if len(input_tensors) != 1:
            return False

        first_input_name = self.input[0]
        if next(iter(input_tensors)) not in first_input_name:
            return False

        numerator = input_tensors.get(first_input_name)
        if numerator is None:
            return False

        # np.all avoids the "truth value of an array is ambiguous" ValueError that a plain
        # `numerator == 1` test raises when the constant has more than one element.
        return bool(np.all(numerator == 1))

    def get_normalization_activation(self, consumed_vertices):
        """Pick the activation type implied by a normalization chain.

        Scans this vertex followed by ``consumed_vertices``. The first division vertex
        whose ``get_normalization_node_values()`` reports input index 0 yields
        ``ActivationType.inv_pos``; if none does, ``ActivationType.linear`` is returned.
        """
        for node in (self, *consumed_vertices):
            if node.op not in DIV_OPS:
                continue
            _, const_input_idx = node.get_normalization_node_values()
            if const_input_idx == 0:  # Div(c, x)
                return ActivationType.inv_pos
        return ActivationType.linear

    def get_pack_values_count(self):
        """Return the ``ValuesCount`` field of the operator's Pack options."""
        raw_options = self._info.BuiltinOptions()
        pack_options = PackOptions()
        pack_options.Init(raw_options.Bytes, raw_options.Pos)
        values_count = pack_options.ValuesCount()
        return values_count

    def get_pack_axes(self):
        """Return the pack axis from the operator's Pack options as a non-negative index.

        A non-negative axis (including 0) is returned as is. A negative axis is offset by
        the rank of the first input shape.
        """
        builtin_ops = self._info.BuiltinOptions()
        ops = PackOptions()
        ops.Init(builtin_ops.Bytes, builtin_ops.Pos)
        axis = ops.Axis()
        # Axis 0 is already a valid non-negative index and must not be shifted by the rank.
        if axis >= 0:
            return axis
        # NOTE(review): negative axes are offset by the input rank; a pack output has one
        # more dimension than its inputs, so confirm whether rank + 1 is intended here.
        return axis + len(self.input_shapes[0])

    def get_resize_pack_info(self):
        """Extract resize attributes from a pack-based forward chain starting at this pack.

        Two chains are looked up from the current (pack) vertex:
          * pack-pack-reshape: consumed vertices are ``[pack, reshape]``
          * pack-reshape-pack-reshape: consumed vertices are ``[reshape, pack, reshape]``

        Returns:
            ``(self._attrs_dict, consumed_vertices)`` after filling the resize-related
            entries of ``self._attrs_dict``.
        """
        candidate_chains = [
            [FwdChainNode(op="PACK"), FwdChainNode(op="RESHAPE")],
            [FwdChainNode(op="RESHAPE"), FwdChainNode(op="PACK"), FwdChainNode(op="RESHAPE")],
        ]
        consumed_vertices = get_all_nodes_from_possible_chains(self._graph, self, candidate_chains)

        # Three consumed vertices means pack-reshape-pack-reshape, where the second pack is
        # at index 1; in the pack-pack-reshape case it is the first consumed vertex.
        second_pack = consumed_vertices[1] if len(consumed_vertices) == 3 else consumed_vertices[0]
        final_reshape = consumed_vertices[-1]

        attrs = self._attrs_dict
        attrs["w_sizes"] = self.get_pack_values_count()
        attrs["h_sizes"] = second_pack.get_pack_values_count()
        attrs["d_sizes"] = None
        attrs["resize_method"] = ResizeMethod.nearest_neighbor
        attrs["pixels_mode"] = ResizeBilinearPixelsMode.disabled
        attrs["output_shapes"] = final_reshape.output_shapes
        return self._attrs_dict, consumed_vertices

    def is_pack_resize(self):
        """Tell whether this pack vertex starts a pack-based resize chain.

        Currently supported cases:
        1. rank 4, pack(width axis)->pack(height axis)->reshape
        2. rank 4, pack(width axis)->reshape->pack(height axis)->reshape
        """
        if self.op not in PACK_OPS:
            return False

        candidate_chains = [
            [FwdChainNode(op="PACK"), FwdChainNode(op="RESHAPE")],
            [FwdChainNode(op="RESHAPE"), FwdChainNode(op="PACK"), FwdChainNode(op="RESHAPE")],
        ]
        chain_nodes = get_all_nodes_from_possible_chains(self._graph, self, candidate_chains)
        if not chain_nodes:
            return False

        # The second pack is first in the two-node chain and second in the three-node chain.
        second_pack_idx = 0 if len(chain_nodes) == 2 else 1
        second_pack = chain_nodes[second_pack_idx]

        return (
            len(self.input_shapes[0]) == 4
            and self.get_pack_axes() == 3
            and second_pack.get_pack_axes() == 2
        )

    def is_ew_op_with_multi_non_const_inputs(self):
        """Tell whether this is an op in ``EW_OPS`` fed by more than one non-constant predecessor."""
        non_const_count = sum(1 for pred in self.graph.predecessors(self) if not pred.is_const())
        return self.op in EW_OPS and non_const_count > 1

    def is_optional_graph_start_node(self):
        """Tell whether this vertex may serve as a graph start node.

        Constants, shape ops, element-wise ops with several non-constant inputs, and ops
        listed in ``MUL_OPS``, ``OUTPUT_OPS`` or ``LOGICAL_OPS`` are all rejected.
        """
        if self.is_const():
            return False
        if self.is_shape_op():
            return False
        if self.is_ew_op_with_multi_non_const_inputs():
            return False
        excluded_ops = MUL_OPS + OUTPUT_OPS + LOGICAL_OPS
        return self.op not in excluded_ops


class TFLiteGraph(NNGraph):
    """Graph of ``TFLiteGraphNode`` vertices built from the first subgraph of a TFLite model."""

    def __init__(self, model, values):
        """Build the graph vertices and edges.

        Args:
            model: forwarded to ``NNGraph.__init__``; presumably the parsed TFLite model
                that is later read through ``self._raw_proto`` — confirm in ``NNGraph``.
            values: forwarded to ``NNGraph.__init__``; presumably the tensor-name to value
                mapping later read through ``self._values`` — confirm in ``NNGraph``.
        """
        super().__init__(model, values)
        # Only the first subgraph of the model is used.
        self._model_graph = self._raw_proto.Subgraphs(0)
        self._model_tensors = [self._model_graph.Tensors(i) for i in range(self._model_graph.TensorsLength())]
        self._model_tensors_names = [t.Name().decode("utf-8") for t in self._model_tensors]
        self._model_operators = [self._model_graph.Operators(i) for i in range(self._model_graph.OperatorsLength())]
        self._init_vertices()
        self._init_vertices_connections()

    def _init_vertices(self):
        """Create one vertex per subgraph input tensor and one per operator."""
        # Input indices that are not greater than -1 are skipped.
        inputs = [self._model_tensors[i] for i in self._model_graph.InputsAsNumpy() if i > -1]
        for inp in inputs:
            vertex = TFLiteGraphNode(inp, self, is_input_vertex=True)
            self.add_node(vertex)
            self.add_vertex_by_name(vertex)

        for op in self._model_operators:
            vertex = TFLiteGraphNode(op, self)
            self.add_node(vertex)
            self.add_vertex_by_name(vertex)

    def _init_vertices_connections(self):
        """Normalize multi-output names, add the graph edges and map values to vertices.

        Works in three passes over ``self._vertices_by_name``:
        1. rename the outputs of every non-input vertex that has more than one output;
        2. rewrite vertex inputs so they refer to those renamed outputs;
        3. add an edge for every input whose source vertex is known and fill
           ``self._values_by_vertex_name``.
        """
        self._multiple_io_vertices = []
        # Pass 1: outputs of multi-output vertices become "<vertex name>:<suffix>", where
        # the suffix is whatever follows the last occurrence of the vertex name in the
        # output name. An output equal to the vertex name is left as is.
        for vertex in self._vertices_by_name.values():
            if vertex.op != INPUT_OP and len(vertex.output) > 1:
                vertex.output = [
                    f"{vertex.name}:{x.split(vertex.name)[-1]}" if x != vertex.name else x for x in vertex.output
                ]
                self._multiple_io_vertices.append(vertex)

        # Pass 2: apply the same renaming to each input and keep the renamed form when it
        # matches one of the outputs of a multi-output vertex.
        for vertex in self._vertices_by_name.values():
            if vertex.op != INPUT_OP:
                for i, inp in enumerate(vertex.input):
                    for mult_io_vertex in self._multiple_io_vertices:
                        if inp == mult_io_vertex.name:
                            key = inp
                        else:
                            key = f"{mult_io_vertex.name}:{inp.split(mult_io_vertex.name)[-1]}"
                        if key in mult_io_vertex.output:
                            vertex.input[i] = key

        # Pass 3: connect vertices and collect the values that belong to each vertex.
        self._values_by_vertex_name = {}
        for vertex in self._vertices_by_name.values():
            self._values_by_vertex_name[vertex.name] = {}
            if vertex.op != INPUT_OP:
                for inp_name in vertex.input:
                    # Drop everything from the first ":" on, unless the name contains ":0".
                    inp_name_src = inp_name.split(":")[0] if ":0" not in inp_name else inp_name
                    if inp_name_src in self._vertices_by_name:
                        self.add_edge(self._vertices_by_name[inp_name_src], vertex)

                    # Substring match: every value whose key contains the source name is taken.
                    possible_values = {x: y for x, y in self._values.items() if inp_name_src in x}
                    self._values_by_vertex_name[vertex.name].update(possible_values)

    @property
    def net_input(self):
        """Vertices looked up by the names of the subgraph input tensors (indices > -1)."""
        input_names = [self._model_tensors_names[i] for i in self._model_graph.InputsAsNumpy() if i > -1]
        return [self.get_vertex_by_name(name) for name in input_names]

    @property
    def vertices_by_name(self):
        """The ``_vertices_by_name`` mapping."""
        return self._vertices_by_name

    @property
    def values_by_vertex_name(self):
        """Mapping from vertex name to the dict of values collected for that vertex."""
        return self._values_by_vertex_name
