#!/usr/bin/env python
import copy
import os
import re
from collections import namedtuple
from enum import Enum

import networkx as nx
import numpy as np
import tensorflow as tf
from past.utils import old_div

from hailo_model_optimization.acceleras.utils.acceleras_definitions import DEFAULT_CONCAT_AXIS, ConcatAxis
from hailo_sdk_client.model_translator.exceptions import (
    BatchNormParsingError,
    CantFindGraphStartError,
    CantFindLeakyAlphaError,
    CantFindPReluSlopeError,
    CantFindSigmoidParametersError,
    CantFindSwishBetaError,
    CantFindThresholdError,
    UnexpectedNodeError,
    UnsupportedActivationLayerError,
    UnsupportedAddLayerError,
    UnsupportedConcatLayerError,
    UnsupportedConvLayerError,
    UnsupportedDilationError,
    UnsupportedFeatureSplitterError,
    UnsupportedModelError,
    UnsupportedNormalizationLayerError,
    UnsupportedPaddingError,
    UnsupportedPoolingLayerError,
    UnsupportedReduceMeanLayerError,
    UnsupportedResizeLayerError,
    UnsupportedSliceLayerError,
)
from hailo_sdk_client.model_translator.graph_lookup import (
    BwdChainNode,
    FwdChainNode,
    get_all_nodes_in_chain,
    get_node_from_possible_chains,
    look_and_validate,
    look_for_node,
    test_scope,
)
from hailo_sdk_client.model_translator.nn_graph import NNGraph, NNGraphNode
from hailo_sdk_client.model_translator.tf_translator.exceptions import TFVariableError
from hailo_sdk_common.compatibility import ensure_str
from hailo_sdk_common.hailo_nn.hn_definitions import PaddingType, ResizeBilinearPixelsMode
from hailo_sdk_common.hailo_nn.hn_layers import BatchNormValues
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.paths_manager.paths import SDKPaths
from hailo_sdk_common.tools.models_translator_helper import is_spatial_broadcast

# Lightweight record type; the field order is part of its positional interface.
ResizeInfo = namedtuple(
    "ResizeInfo",
    "expected_output_shape forced_by_unknown_shape is_upscale_factors",
)

# --- Op-type names, grouped by the role the translator gives them -------------
INPUT_OPS = ["Placeholder", "PlaceholderV2", "PlaceholderWithDefault", "IteratorGetNext"]
PAD_OPS = ["Pad", "PadV2"]
CONCAT_OPS = ["Concat", "ConcatV2"]
CONV2D_OPS = ["Conv2D", "DepthwiseConv2dNative", "Conv2DBackpropInput"]
DILATION_OPS = ["SpaceToBatchND", "BatchToSpaceND"]
DENSE_OPS = ["MatMul"]
EINSUM_OPS = ["Einsum"]
POOL_OPS = ["AvgPool", "MaxPool"]
BN_OPS = ["BatchNorm", "FusedBatchNorm", "FusedBatchNormV3"]
LOGITS_OPS = ["Softmax", "ArgMax"]
REDUCE_MAX_OPS = ["Max"]
REDUCE_SUM_OPS = ["Sum"]
ADD_OPS = ["Add", "AddV2"]
ADD_N_OPS = ["AddN"]
BIAS_ADD_OPS = ["BiasAdd"]
SLICE_OPS = ["StridedSlice", "Slice"]
ACTIVATION_OPS = [
    "Relu", "Relu6", "Elu", "LeakyRelu", "Sigmoid",
    "Exp", "Tanh", "Softplus", "Erf", "Sqrt",
    "Less", "Log", "Softsign",
]
SPLIT_OPS = ["Split", "SplitV"]
RESIZE_OPS = ["ResizeNearestNeighbor", "ResizeBilinear"]
SHUFFLE_OPS = ["Reshape", "Transpose", "DepthToSpace"]
SPACE_TO_DEPTH_OPS = ["SpaceToDepth"]
NORMALIZATION_OPS = ADD_OPS + ["Sub", "Mul", "RealDiv", "Neg"]
SQUARE_OPS = ["Square", "Pow", "Mul"]
DIV_OPS = ["RealDiv", "Reciprocal"]

MATH_OPS = [
    "Mul", "Sub", "Cast", "Floor", "Rsqrt",
    "Minimum", "Maximum", "Greater", "Mean", "RealDiv",
    "Sign", "Abs", "Neg",
]
VAR_OPS = ["Const", "Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"]

# --- Composite lists (order and duplicates are preserved on purpose) ----------
GRAPH_START_OPS = [
    *PAD_OPS,
    *CONV2D_OPS,
    *DENSE_OPS,
    *CONCAT_OPS,
    *ADD_OPS,
    *BIAS_ADD_OPS,
    *RESIZE_OPS,
]
NEW_LAYER_OPS = [
    *PAD_OPS,
    *CONCAT_OPS,
    *CONV2D_OPS,
    *ACTIVATION_OPS,
    *SPLIT_OPS,
    *RESIZE_OPS,
    *SPACE_TO_DEPTH_OPS,
    *EINSUM_OPS,
    *DENSE_OPS,
    # individual op names that are not covered by the groups above
    "MaxPool",
    "AvgPool",
    "Maximum",
    "SpaceToBatchND",
    "DepthToSpace",
    "ArgMax",
    "Softmax",
    "StridedSlice",
    "Mean",
    "If",
    "Square",
    "Pow",
]

OTHER_OPS = [
    "Identity", "Switch", "Merge", "ExpandDims", "Squeeze",
    "Shape", "If", "IdentityN", "StatelessIf", "SelectV2",
]

# Concatenation of every group above, in a fixed order.
SUPPORTED_OPS_UNION = [
    *INPUT_OPS,
    *PAD_OPS,
    *CONCAT_OPS,
    *CONV2D_OPS,
    *DILATION_OPS,
    *DENSE_OPS,
    *POOL_OPS,
    *BN_OPS,
    *LOGITS_OPS,
    *ADD_OPS,
    *BIAS_ADD_OPS,
    *SLICE_OPS,
    *ACTIVATION_OPS,
    *SPLIT_OPS,
    *RESIZE_OPS,
    *SHUFFLE_OPS,
    *MATH_OPS,
    *OTHER_OPS,
    *VAR_OPS,
    *REDUCE_MAX_OPS,
    *SPACE_TO_DEPTH_OPS,
    *REDUCE_SUM_OPS,
    *EINSUM_OPS,
    *SQUARE_OPS,
    *ADD_N_OPS,
    *DIV_OPS,
]

# --- Name fragments and batch-norm defaults -----------------------------------
BATCH_NORM_NAMES = ["batch_normalization", "batchnorm", "batch_norm", "_bn", "bn_"]
# Positional order used to pair a batch-norm node's inputs with their meaning.
BATCH_NORM_INPUT_ORDER = ["X", "gamma", "beta", "moving_mean", "moving_var"]
DROPOUT_NAMES = ["dropout", "Dropout"]
TRAINING_NAMES = ["train", "learning_phase"]
DEFAULT_BN_EPSILON = 1e-3


class AddOpRole(Enum):
    """Closed set of string-valued role tags; each member's value equals its name.

    NOTE(review): judging by the class name these presumably classify how an
    add-type op is interpreted (element-wise add, bias add, or part of a
    normalization) -- confirm against the code that assigns them.
    """

    ew = "ew"
    bias = "bias"
    normalization = "normalization"


class BNParser:
    # Candidate chains used by get_bn_moving_mean() when the parser was created on
    # a "Mul" node inside a "batchnorm" scope (self._use_mul_chains). Each inner
    # list is one sequence of FwdChainNode/BwdChainNode steps handed to
    # get_node_from_possible_chains() with self._node as the starting point; the
    # last step describes the op (and name) of the node expected to hold the
    # moving mean.
    # NOTE(review): presumably the chains are tried in list order and the first
    # match wins -- confirm in graph_lookup before reordering entries.
    MUL_MOVING_MEAN_CHAINS = [
        [
            BwdChainNode(op="Mul"),
            FwdChainNode(op="Mul"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="moving_mean"),
        ],
        [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="moving_mean"),
        ],
        [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Merge"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="moving_mean"),
        ],
        # Frozen Graph
        [
            BwdChainNode(op="Mul"),
            FwdChainNode(op="Mul"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_mean"),
        ],
        [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_mean"),
        ],
        [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Merge"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_mean"),
        ],
        [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Merge"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Const", name="moving_mean"),
        ],
        [
            FwdChainNode(op="Add"),
            FwdChainNode(op="Switch"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_mean"),
        ],
        [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="mean"),
        ],
        # TF1.13 chains that stem from mul op
        [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="VarHandleOp", name="moving_mean"),
        ],
        # Keras TF2 compatibility
        [
            BwdChainNode(op="Mul"),
            FwdChainNode(op="Mul"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="If"),
            BwdChainNode(op="VarHandleOp", name="moving_mean"),
        ],
        [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Neg"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="If"),
            BwdChainNode(op="VarHandleOp", name="moving_mean"),
        ],
        [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Neg"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="VarHandleOp", name="moving_mean"),
        ],
        [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="If"),
            BwdChainNode(op="VarHandleOp", name="moving_mean"),
        ],
        # Chains that stem from rank 3
        [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Merge"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="VarHandleOp", name="moving_mean"),
        ],
        [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Neg"),
            BwdChainNode(op="Merge"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="VarHandleOp", name="moving_mean"),
        ],
    ]

    # Candidate chains used by get_bn_moving_variance() when the parser was created
    # on a "Mul" node inside a "batchnorm" scope (self._use_mul_chains). Each inner
    # list is one sequence of FwdChainNode/BwdChainNode steps handed to
    # get_node_from_possible_chains() with self._node as the starting point; the
    # last step describes the node expected to hold the moving variance.
    # NOTE(review): presumably the chains are tried in list order and the first
    # match wins -- confirm in graph_lookup before reordering entries.
    MUL_MOVING_VAR_CHAINS = [
        [
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="moving_variance"),
        ],
        [
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="moving_variance"),
        ],
        [
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Merge"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="moving_variance"),
        ],
        [
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Merge"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="moving_variance"),
        ],
        # Frozen Graph
        [
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_variance"),
        ],
        [
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_variance"),
        ],
        [
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Merge"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_variance"),
        ],
        [
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Merge"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_variance"),
        ],
        [
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Merge"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Const", name="moving_variance"),
        ],
        [
            FwdChainNode(op="Add"),
            FwdChainNode(op="Switch"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_variance"),
        ],
        [
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="var"),
        ],
        # TF1.13 chains that stem from mul op
        [
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="VarHandleOp", name="moving_variance"),
        ],
        # Keras TF2 compatibility
        [
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="If"),
            BwdChainNode(op="VarHandleOp", name="moving_variance"),
        ],
        [
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="If"),
            BwdChainNode(op="VarHandleOp", name="moving_variance"),
        ],
        [
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="VarHandleOp", name="moving_variance"),
        ],
        # Chains that stem from rank 3
        [
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Merge"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="VarHandleOp", name="moving_variance"),
        ],
        [
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Merge"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="VarHandleOp", name="moving_variance"),
        ],
    ]

    # Candidate chains used by get_bn_beta() when the parser was created on a "Mul"
    # node inside a "batchnorm" scope (self._use_mul_chains). Every chain steps
    # forward to an "Add", back through a "Sub", and ends at the node expected to
    # hold beta. If none matches, get_bn_beta() falls back to 0.0.
    MUL_BETA_CHAINS = [
        [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="beta"),
        ],
        # Frozen Graph
        [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="beta"),
        ],
        [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="bias"),
        ],
        # TF1.13 chains that stem from mul op
        [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="VarHandleOp", name="beta"),
        ],
    ]

    # Candidate chains used by get_bn_gamma() when the parser was created on a "Mul"
    # node inside a "batchnorm" scope (self._use_mul_chains). Every chain steps back
    # through a "Mul" and ends at the node expected to hold gamma. If none matches,
    # get_bn_gamma() falls back to 1.0.
    MUL_GAMMA_CHAINS = [
        [BwdChainNode(op="Mul"), BwdChainNode(op="Identity", name="read"), BwdChainNode(op="Variable", name="gamma")],
        # Frozen Graph
        [BwdChainNode(op="Mul"), BwdChainNode(op="Identity"), BwdChainNode(op="Const", name="gamma")],
        [BwdChainNode(op="Mul"), BwdChainNode(op="Identity"), BwdChainNode(op="Const", name="scale")],
        # TF1.13 chains that stem from mul op
        [BwdChainNode(op="Mul"), BwdChainNode(op="ReadVariableOp"), BwdChainNode(op="VarHandleOp", name="gamma")],
    ]

    # Candidate chains used by get_bn_epsilon() when self._use_mul_chains is set.
    # Both end at an un-named "Const" reached through Rsqrt -> Add; get_bn_epsilon()
    # reads that node's attr["value"].tensor.float_val as the epsilon.
    MUL_EPSILON_CHAINS = [
        [BwdChainNode(op="Mul"), BwdChainNode(op="Rsqrt"), BwdChainNode(op="Add"), BwdChainNode(op="Const")],
        [BwdChainNode(op="Rsqrt"), BwdChainNode(op="Add"), BwdChainNode(op="Const")],
    ]

    # Default candidate chains for get_bn_moving_mean(): used when neither
    # self._use_mul_chains nor self._use_keras_if_chains is set, i.e. the parser
    # was created on a "Switch", "FusedBatchNorm" or "FusedBatchNormV3" node.
    # Each inner list is one sequence of FwdChainNode/BwdChainNode steps handed to
    # get_node_from_possible_chains() with self._node as the starting point; the
    # last step describes the node expected to hold the moving mean.
    # NOTE(review): presumably the chains are tried in list order and the first
    # match wins -- confirm in graph_lookup before reordering entries.
    FBN_MOVING_MEAN_CHAINS = [
        # Chains that cover cases induced by is_training = placeholder
        [
            FwdChainNode(op="FusedBatchNorm"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="moving_mean"),
        ],
        # Frozen Graph
        [
            FwdChainNode(op="FusedBatchNorm"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_mean"),
        ],
        [
            FwdChainNode(op="FusedBatchNorm"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Const", name="moving_mean"),
        ],
        # Chains that cover cases induced by is_training = True
        [
            FwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="moving_mean"),
        ],
        [
            FwdChainNode(op="Sub"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="moving_mean"),
        ],
        # Frozen Graph
        [
            FwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_mean"),
        ],
        [FwdChainNode(op="Sub"), BwdChainNode(op="Identity"), BwdChainNode(op="Const", name="moving_mean")],
        # Chains that cover cases induced by is_training = False
        [BwdChainNode(op="Identity", name="read"), BwdChainNode(op="Variable", name="moving_mean")],
        # Frozen Graph
        [BwdChainNode(op="Identity"), BwdChainNode(op="Const", name="moving_mean")],
        [BwdChainNode(op="Const", name="moving_mean")],
        # TF1.13 chains that stem from FBN op
        [BwdChainNode(op="ReadVariableOp"), BwdChainNode(op="VarHandleOp", name="moving_mean")],
        [
            FwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="VarHandleOp", name="moving_mean"),
        ],
        [
            FwdChainNode(op="FusedBatchNorm"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="VarHandleOp", name="moving_mean"),
        ],
        [
            FwdChainNode(op="Switch"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="VarHandleOp", name="moving_mean"),
        ],
        # Chains that cover different implementation of BN (keras?)
        [
            FwdChainNode(op="Mul"),
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Variable", name="moving_mean"),
        ],
        [
            FwdChainNode(op="Switch"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Variable", name="moving_mean"),
        ],
        # Frozen Graph
        [
            FwdChainNode(op="Mul"),
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_mean"),
        ],
        [
            FwdChainNode(op="Switch"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_mean"),
        ],
        # Keras TF2 compatibility
        [FwdChainNode(op="Sub"), BwdChainNode(op="ReadVariableOp"), BwdChainNode(op="VarHandleOp", name="moving_mean")],
    ]

    # Default candidate chains for get_bn_moving_variance(): used when neither
    # self._use_mul_chains nor self._use_keras_if_chains is set, i.e. the parser
    # was created on a "Switch", "FusedBatchNorm" or "FusedBatchNormV3" node.
    # Each inner list is one sequence of FwdChainNode/BwdChainNode steps handed to
    # get_node_from_possible_chains() with self._node as the starting point; the
    # last step describes the node expected to hold the moving variance.
    # NOTE(review): presumably the chains are tried in list order and the first
    # match wins -- confirm in graph_lookup before reordering entries.
    FBN_MOVING_VAR_CHAINS = [
        # Chains that cover cases induced by is_training = placeholder
        [
            FwdChainNode(op="FusedBatchNorm"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="moving_variance"),
        ],
        # Frozen Graph
        [
            FwdChainNode(op="FusedBatchNorm"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_variance"),
        ],
        [
            FwdChainNode(op="FusedBatchNorm"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Const", name="moving_variance"),
        ],
        # Chains that cover cases induced by is_training = True
        [
            FwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="moving_variance"),
        ],
        [
            FwdChainNode(op="Sub"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="moving_variance"),
        ],
        # Frozen Graph
        # The frozen counterpart of the is_training = True chain must end in a
        # "Const" (as the matching entry of FBN_MOVING_MEAN_CHAINS does); it was
        # missing here, so frozen graphs of this shape could not be resolved.
        [
            FwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_variance"),
        ],
        # Pre-existing entry that sat under the "Frozen Graph" label although it
        # ends in a "Variable"; kept unchanged so graphs it already matched keep
        # resolving the same way.
        [
            FwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Variable", name="moving_variance"),
        ],
        [FwdChainNode(op="Sub"), BwdChainNode(op="Identity"), BwdChainNode(op="Const", name="moving_variance")],
        # Chains that cover cases induced by is_training = False
        [BwdChainNode(op="Identity", name="read"), BwdChainNode(op="Variable", name="moving_variance")],
        [
            FwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="VarHandleOp", name="moving_variance"),
        ],
        [
            FwdChainNode(op="FusedBatchNorm"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="VarHandleOp", name="moving_variance"),
        ],
        [
            FwdChainNode(op="Switch"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="VarHandleOp", name="moving_variance"),
        ],
        # Frozen Graph
        [BwdChainNode(op="Identity"), BwdChainNode(op="Const", name="moving_variance")],
        [BwdChainNode(op="Const", name="moving_variance")],
        [
            FwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Const", name="moving_variance"),
        ],
        # TF1.13 chains that stem from FBN op
        [BwdChainNode(op="ReadVariableOp"), BwdChainNode(op="VarHandleOp", name="moving_variance")],
        # Chains that cover different implementation of BN (keras?)
        [
            FwdChainNode(op="Mul"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="moving_variance"),
        ],
        [
            FwdChainNode(op="Switch"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="moving_variance"),
        ],
        # Frozen Graph
        [
            FwdChainNode(op="Mul"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_variance"),
        ],
        [
            FwdChainNode(op="Switch"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="moving_variance"),
        ],
        # Keras TF2 compatibility
        [
            FwdChainNode(op="Sub"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="VarHandleOp", name="moving_variance"),
        ],
    ]

    # Default candidate chains for get_bn_gamma(): used when neither
    # self._use_mul_chains nor self._use_keras_if_chains is set. The last step of
    # each chain describes the node expected to hold gamma; if none matches,
    # get_bn_gamma() falls back to 1.0.
    # NOTE(review): the final "Keras TF2" entry is identical to the first
    # "TF1.13" entry above it, so it looks redundant -- confirm before removing.
    FBN_GAMMA_CHAINS = [
        [
            FwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="gamma"),
        ],
        [BwdChainNode(op="Identity", name="read"), BwdChainNode(op="Variable", name="gamma")],
        [
            FwdChainNode(op="Mul"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="gamma"),
        ],
        # Frozen Graph
        [
            FwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="gamma"),
        ],
        [BwdChainNode(op="Identity"), BwdChainNode(op="Const", name="gamma")],
        [BwdChainNode(op="Const", name="gamma")],
        [
            FwdChainNode(op="Mul"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="gamma"),
        ],
        # TF1.13 chains that stem from FBN op
        [BwdChainNode(op="ReadVariableOp"), BwdChainNode(op="VarHandleOp", name="gamma")],
        [
            FwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="VarHandleOp", name="gamma"),
        ],
        # Frozen Graph
        [
            FwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Const", name="gamma"),
        ],
        # Keras TF2 compatibility
        [BwdChainNode(op="ReadVariableOp"), BwdChainNode(op="VarHandleOp", name="gamma")],
    ]

    # Default candidate chains for get_bn_beta(): used when neither
    # self._use_mul_chains nor self._use_keras_if_chains is set. The last step of
    # each chain describes the node expected to hold beta; if none matches,
    # get_bn_beta() falls back to 0.0.
    # NOTE(review): the final "Keras TF2" entry is identical to the first
    # "TF1.13" entry above it, so it looks redundant -- confirm before removing.
    FBN_BETA_CHAINS = [
        [
            FwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="beta"),
        ],
        [BwdChainNode(op="Identity", name="read"), BwdChainNode(op="Variable", name="beta")],
        [
            FwdChainNode(op="Mul"),
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity", name="read"),
            BwdChainNode(op="Variable", name="beta"),
        ],
        # Frozen Graph
        [
            FwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="beta"),
        ],
        [BwdChainNode(op="Identity"), BwdChainNode(op="Const", name="beta")],
        [BwdChainNode(op="Const", name="beta")],
        [
            FwdChainNode(op="Mul"),
            FwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Const", name="beta"),
        ],
        # TF1.13 chains that stem from FBN op
        [BwdChainNode(op="ReadVariableOp"), BwdChainNode(op="VarHandleOp", name="beta")],
        [
            FwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="VarHandleOp", name="beta"),
        ],
        # Frozen Graph
        [
            FwdChainNode(op="FusedBatchNorm"),
            BwdChainNode(op="Identity"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="Const", name="beta"),
        ],
        # Keras TF2 compatibility
        [BwdChainNode(op="ReadVariableOp"), BwdChainNode(op="VarHandleOp", name="beta")],
    ]

    # Candidate chains used by get_bn_epsilon() when the starting node is not
    # itself one of BN_OPS and neither the Mul nor the Keras-If mode is active.
    # The first two chains end at a "FusedBatchNorm" node; the third ends at an
    # un-named "Const" reached through Rsqrt -> Add.
    FBN_EPSILON_CHAINS = [
        [FwdChainNode(op="FusedBatchNorm")],
        [
            FwdChainNode(op="Mul"),
            FwdChainNode(op="Add"),
            FwdChainNode(op="Merge"),
            BwdChainNode(op="Switch"),
            BwdChainNode(op="FusedBatchNorm"),
        ],
        [
            BwdChainNode(op="Identity"),
            FwdChainNode(op="Switch"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Sub"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Rsqrt"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Const"),
        ],
    ]

    def __init__(self, node, logger=None):
        """Bind the parser to ``node`` and decide which chain family to search.

        Args:
            node: graph node the batch-norm lookup starts from; its ``op`` and
                ``name`` select the mode.
            logger: optional logger; ``default_logger()`` is used when falsy.

        Raises:
            BatchNormParsingError: if ``node`` is neither a "Mul" inside a
                "batchnorm" scope, an "If", nor one of "Switch",
                "FusedBatchNorm", "FusedBatchNormV3".
        """
        self._node = node
        self._logger = logger if logger else default_logger()

        # A Mul only counts when it lives inside a "batchnorm" scope; a Mul
        # outside such a scope falls through to the unsupported-op check below.
        mul_mode = bool(node.op == "Mul" and test_scope(node.name, "batchnorm"))
        keras_if_mode = (not mul_mode) and node.op == "If"

        fused_ops = ("Switch", "FusedBatchNorm", "FusedBatchNormV3")
        if not mul_mode and not keras_if_mode and node.op not in fused_ops:
            raise BatchNormParsingError(node.name)

        self._use_mul_chains = mul_mode
        self._use_keras_if_chains = keras_if_mode

    def _get_bn_value_from_possible_chains(self, possible_chains, var_name):
        """Resolve one batch-norm parameter value by walking candidate chains.

        Args:
            possible_chains: list of chains passed to
                ``get_node_from_possible_chains`` starting at ``self._node``.
            var_name: key from ``BATCH_NORM_INPUT_ORDER`` naming the parameter;
                only used by the positional fallback.

        Returns:
            The value stored in ``graph.values`` under the matched node's name,
            or the ndarray decoded from a matched/fallback ``Const`` node.

        Raises:
            BatchNormParsingError: if no value could be resolved.
        """
        graph = self._node.graph
        matched = get_node_from_possible_chains(graph, self._node, possible_chains)

        if matched is None:
            # fallback with edge case in frozen models - the chain can't describe the name of the const
            # Pair the node's inputs positionally with the expected BN input order.
            inputs_by_role = dict(zip(BATCH_NORM_INPUT_ORDER, self._node.get_inputs()))
            has_every_role = all(role in inputs_by_role for role in BATCH_NORM_INPUT_ORDER)
            if has_every_role and inputs_by_role[var_name] in graph.vertices_by_name:
                start_vertex = graph.get_vertex_by_name(inputs_by_role[var_name])
                const_node = look_for_node(graph, start_vertex, [BwdChainNode(op="Const")])
                if const_node is not None:
                    return tf.make_ndarray(const_node._info.attr["value"].tensor)
        elif matched.name in graph.values:
            return graph.values[matched.name]
        elif matched.op == "Const":
            return tf.make_ndarray(matched._info.attr["value"].tensor)

        raise BatchNormParsingError(self._node.name)

    def get_bn_gamma(self):
        """Return the batch-norm gamma, or a float32 scalar 1.0 when it cannot be found."""
        try:
            if self._use_mul_chains:
                candidate_chains = type(self).MUL_GAMMA_CHAINS
            elif self._use_keras_if_chains:
                candidate_chains = [[BwdChainNode(op="VarHandleOp", name="gamma")]]
            else:
                candidate_chains = type(self).FBN_GAMMA_CHAINS
            return self._get_bn_value_from_possible_chains(candidate_chains, "gamma")
        except BatchNormParsingError:
            # A missing gamma is tolerated: fall back to the default 1.0.
            self._logger.debug(f"Gamma value not found from node {self._node.name}. Assumed gamma = 1.0.")
            return np.array(1.0, dtype=np.float32)

    def get_bn_moving_mean(self):
        """Return the batch-norm moving mean.

        Raises:
            BatchNormParsingError: propagated from the chain lookup when no
                value is found (there is no default for this parameter).
        """
        if self._use_mul_chains:
            candidate_chains = type(self).MUL_MOVING_MEAN_CHAINS
        elif self._use_keras_if_chains:
            candidate_chains = [[BwdChainNode(op="VarHandleOp", name="moving_mean")]]
        else:
            candidate_chains = type(self).FBN_MOVING_MEAN_CHAINS
        return self._get_bn_value_from_possible_chains(candidate_chains, "moving_mean")

    def get_bn_moving_variance(self):
        """Return the batch-norm moving variance.

        Raises:
            BatchNormParsingError: propagated from the chain lookup when no
                value is found (there is no default for this parameter).
        """
        if self._use_mul_chains:
            candidate_chains = type(self).MUL_MOVING_VAR_CHAINS
        elif self._use_keras_if_chains:
            candidate_chains = [[BwdChainNode(op="VarHandleOp", name="moving_var")]]
        else:
            candidate_chains = type(self).FBN_MOVING_VAR_CHAINS
        # "moving_var" is the key used in BATCH_NORM_INPUT_ORDER for the fallback.
        return self._get_bn_value_from_possible_chains(candidate_chains, "moving_var")

    def get_bn_beta(self):
        """Return the batch-norm beta, or a float32 scalar 0.0 when it cannot be found."""
        try:
            if self._use_mul_chains:
                candidate_chains = type(self).MUL_BETA_CHAINS
            elif self._use_keras_if_chains:
                candidate_chains = [[BwdChainNode(op="VarHandleOp", name="beta")]]
            else:
                candidate_chains = type(self).FBN_BETA_CHAINS
            return self._get_bn_value_from_possible_chains(candidate_chains, "beta")
        except BatchNormParsingError:
            # A missing beta is tolerated: fall back to the default 0.0.
            self._logger.debug(f"Beta value not found from node {self._node.name}. Assumed beta = 0.0.")
            return np.array(0.0, dtype=np.float32)

    def get_bn_epsilon(self):
        """Return the epsilon used by the batch normalization around ``self._node``.

        Returns:
            * the ``epsilon`` attribute of ``self._node`` when its op is in ``BN_OPS``;
            * ``DEFAULT_BN_EPSILON`` in Keras-If mode;
            * otherwise the value taken from the node found through
              ``MUL_EPSILON_CHAINS`` / ``FBN_EPSILON_CHAINS``: an ndarray built
              from the constant's ``float_val`` when that node is a ``Const``,
              or the node's ``epsilon`` attribute when it is a fused BN op.

        Raises:
            BatchNormParsingError: if no chain leads to an epsilon-carrying node.
        """
        if self._node.op in BN_OPS:
            return self._node._info.attr["epsilon"].f
        if self._use_keras_if_chains:
            return DEFAULT_BN_EPSILON

        possible_chains = type(self).MUL_EPSILON_CHAINS if self._use_mul_chains else type(self).FBN_EPSILON_CHAINS
        node = get_node_from_possible_chains(self._node.graph, self._node, possible_chains)
        if node is None:
            raise BatchNormParsingError(self._node.name)

        # One of the FBN chains also terminates at a Const. A Const carries its
        # payload under attr["value"], not attr["epsilon"]; reading the missing
        # "epsilon" key would silently yield the protobuf default of 0.0, so any
        # Const result is decoded the same way as in the Mul case.
        if self._use_mul_chains or node.op == "Const":
            return np.array(node._info.attr["value"].tensor.float_val)
        return node._info.attr["epsilon"].f


class TFGraphNode(NNGraphNode):
    def __init__(self, node_proto, graph, logger=None):
        """Wrap a node proto as a graph vertex.

        Args:
            node_proto: object exposing ``name``, ``op`` and ``input``; these are
                copied onto the vertex.
            graph: graph passed to the base class constructor together with the proto.
            logger: optional logger; ``default_logger()`` is used when it is falsy.
        """
        super().__init__(node_proto, graph)
        self.name = node_proto.name
        self.op = node_proto.op
        self.input = node_proto.input
        if not logger:
            logger = default_logger()
        self._logger = logger

    def get_const_val(self):
        """Return the first scalar stored in this node's ``value`` tensor.

        ``float_val`` is preferred over ``int_val``.

        Raises:
            UnsupportedModelError: when both ``float_val`` and ``int_val`` are empty.
        """
        tensor = self._info.attr["value"].tensor
        for candidates in (tensor.float_val, tensor.int_val):
            if len(candidates) > 0:
                return candidates[0]

        raise UnsupportedModelError(f'Trying to get the const value of node "{self.name}" without float/int values')

    def get_pooling_ksize(self):
        """Return the pooling kernel size as a 4-element list.

        The ``ksize`` attribute is read as ints; when it has 4 entries and the
        graph is NCHW, the entries are reordered as ``[0, 2, 3, 1]``.

        Raises:
            UnsupportedPoolingLayerError: when the result does not have 4 entries
                or its first or last entry is not 1.
        """
        ksize = [int(entry) for entry in self._info.attr["ksize"].list.i]
        if len(ksize) == 4 and self.graph.is_nchw:
            first, second, third, fourth = ksize
            ksize = [first, third, fourth, second]

        # the dims are [batch, height, width, channels]
        # only height and width pooling is supported
        pools_spatial_only = len(ksize) == 4 and ksize[0] == 1 and ksize[3] == 1
        if not pools_spatial_only:
            raise UnsupportedPoolingLayerError(
                f"Unexpected pooling dims at {self.name}, dims={ksize!s}",
            )
        return ksize

    def get_input_vertices(self):
        """Return the vertices that feed data into this start node.

        Returns:
            * ``[self]`` when this node is an input op.
            * For concat ops: every predecessor that is not a variable layer
              (exactly one variable predecessor, the axis, is required).
            * Otherwise the single predecessor that is neither a variable/const
              nor a shape-info layer, or both such predecessors for add ops.

        Raises:
            CantFindGraphStartError: when no data input exists, or when there are
                more data inputs than the op type allows.
        """
        preds = list(self._graph.predecessors(self))
        prefmsg = f"Start node {self.name} of type {self.op}"

        if self.is_input_op():
            return [self]

        # In case the graph starts with a concat node, we want to return its inputs that are not the axis node
        if self.op in CONCAT_OPS:
            data_preds = [pred for pred in preds if not pred._is_var_layer()]
            if len(preds) - len(data_preds) != 1:
                raise CantFindGraphStartError(
                    f"{prefmsg} of concat type doesn't have 1 variable(axis) input, which is not supported.",
                )
            return data_preds

        # This covers cases where start node has no input layers, or one input which is a variable/const
        if not preds or (len(preds) == 1 and preds[0]._is_var_layer()):
            raise CantFindGraphStartError(f"{prefmsg} has no input nodes, which is not supported.")

        # A single predecessor that is not a variable/const is the input
        if len(preds) == 1:
            return [preds[0]]

        # Two or more predecessors: choose by elimination - *not* the variable/const/shape-info nodes
        data_preds = [pred for pred in preds if not pred._is_var_layer() and not pred._is_shape_info_layer()]
        if not data_preds:
            raise CantFindGraphStartError(
                f"{prefmsg} has only variables/consts/shapes as inputs, which is not defined.",
            )
        if len(data_preds) == 1:
            return data_preds
        if len(data_preds) == 2 and self.op in ADD_OPS:
            return data_preds

        # More than one data input on a non-add op, or more than two data inputs on any op
        raise CantFindGraphStartError(f"{prefmsg} has more than one input nodes, which is not supported.")

    def get_resize_info(self):
        """Extract resize information from the vertices around this node.

        Returns a ``ResizeInfo`` named tuple:

        * A direct ``Const`` predecessor holds the expected output shape.
        * A ``Mul`` predecessor fed by a const named ``Const`` holds the upscale
          factors (``is_upscale_factors`` is True), unless a second const named
          ``Const_1`` also feeds the ``Mul``, in which case the product of both
          consts is returned as the expected output shape.
        * With unknown height and width, the output shapes of a following EW-Add
          are returned and ``forced_by_unknown_shape`` is True.
        * Otherwise no expected shape is returned (``None``).

        Raises:
            UnsupportedResizeLayerError: when height and width are unknown and the
                node is not followed by an EW-Add.
        """
        shape_const = look_for_node(self._graph, self, [BwdChainNode(op="Const")])
        if shape_const is not None:
            return ResizeInfo(tf.make_ndarray(shape_const._info.attr["value"].tensor), False, False)

        first_const = look_for_node(
            self._graph,
            self,
            [BwdChainNode(op="Mul"), BwdChainNode(op="Const", name="Const")],
        )
        if first_const is not None:
            first_value = tf.make_ndarray(first_const._info.attr["value"].tensor)
            second_const = look_for_node(
                self._graph,
                self,
                [BwdChainNode(op="Mul"), BwdChainNode(op="Const", name="Const_1")],
            )
            if second_const is None:
                return ResizeInfo(first_value, False, True)
            second_value = tf.make_ndarray(second_const._info.attr["value"].tensor)
            return ResizeInfo(first_value * second_value, False, False)

        output_shape = self.get_output_shapes()[0]
        if output_shape[1:3] == [-1, -1] and output_shape[3] > 0:
            following_add = look_for_node(self._graph, self, [FwdChainNode(op="Add")])
            if following_add and following_add.decide_which_add() == AddOpRole.ew:
                return ResizeInfo(following_add.get_output_shapes(), True, False)
            raise UnsupportedResizeLayerError(
                f"Resize layer {self.name} has unknown output shape, which is "
                f"currently only supported when followed by EW-Add layer.",
            )

        self._logger.debug(
            f"Resize layer {self.name} has explicit resize ratio instead of expected output shape, "
            "and will not be verified against actual output shape.",
        )
        return ResizeInfo(None, False, False)

    def is_1d_resize(self):
        """Tell whether this split node and the concat after it form a 1D resize.

        The pattern requires a rank-3 output shape, a ``split_dim`` const
        predecessor equal to 1, a following ``Concat`` whose const axis is 1, a
        single successor, and a concat input count that is a whole multiple of the
        number of splits.

        Returns:
            bool: True when the pattern matches. A ``num_split`` of 0 never matches
            (it previously raised ``ZeroDivisionError``).
        """
        output_shape = self.get_output_shapes()[0]
        if len(output_shape) != 3:
            return False

        split_axis_node = look_for_node(self._graph, self, [BwdChainNode(op="Const", name="split_dim")])
        concat_node = look_for_node(self._graph, self, [FwdChainNode(op="Concat")])
        concat_axis_node = look_for_node(self._graph, self, [FwdChainNode(op="Concat"), BwdChainNode(op="Const")])
        if not split_axis_node or not concat_node or not concat_axis_node:
            return False

        split_axis = split_axis_node._info.attr["value"].tensor.int_val[0]
        concat_axis = concat_axis_node._info.attr["value"].tensor.int_val[0]
        num_of_splits = self._info.attr["num_split"].i
        num_concat_inputs = concat_node._info.attr["N"].i

        return (
            split_axis == 1
            and concat_axis == 1
            and len(list(self._graph.successors(self))) == 1
            # integer divisibility test; guards against a zero split count
            and num_of_splits != 0
            and num_concat_inputs % num_of_splits == 0
        )

    def get_1d_resize_info(self):
        """Return the upscale factors and output shapes of a split+concat 1D resize.

        The output shapes are those of the following ``Concat`` node, each expanded
        from three entries to four by inserting a 1 at index 1. The second upscale
        factor is the concat input count divided by the number of splits.

        Returns:
            tuple: ``((1, factor), output_shapes)``.
        """
        concat = look_for_node(self._graph, self, [FwdChainNode(op="Concat")])
        expanded_shapes = []
        for shape in concat.get_output_shapes():
            expanded_shapes.append([shape[0], 1, shape[1], shape[2]])
        split_count = self._info.attr["num_split"].i
        concat_input_count = concat._info.attr["N"].i
        factor = int(concat_input_count / split_count)
        return (1, factor), expanded_shapes

    def get_resize_bilinear_pixels_mode(self):
        """Map the node's resize attributes to a ``ResizeBilinearPixelsMode``.

        ``align_corners`` takes precedence over ``half_pixel_centers``; when neither
        attribute is set the mode is ``disabled``.
        """
        attrs = self._info.attr
        if attrs["align_corners"].b:
            return ResizeBilinearPixelsMode.align_corners
        if attrs["half_pixel_centers"].b:
            return ResizeBilinearPixelsMode.half_pixels
        return ResizeBilinearPixelsMode.disabled

    def get_strides(self):
        """Return the ``strides`` attribute as a list of ints.

        When there are 4 entries and the graph is NCHW, the entries are reordered
        as ``[0, 2, 3, 1]``.
        """
        strides = list(map(int, self._info.attr["strides"].list.i))
        if len(strides) == 4 and self.graph.is_nchw:
            first, second, third, fourth = strides
            strides = [first, third, fourth, second]
        return strides

    def get_dilations(self, is_dilations_s2b=False):
        """Return the dilations of this node as a 4-element list.

        Args:
            is_dilations_s2b: when True, the rates are read from a ``block_shape``
                const predecessor and returned as ``[1, rate0, rate1, 1]``.

        Raises:
            UnsupportedModelError: when the node's op is not a supported conv type.
        """
        if is_dilations_s2b:
            block_shape_vertex = look_and_validate(
                self._graph,
                self,
                [BwdChainNode(op="Const", name="block_shape")],
            )
            rates = tf.make_ndarray(block_shape_vertex._info.attr["value"].tensor)
            return [1, rates[0], rates[1], 1]

        if self.op == "Conv2DBackpropInput":
            # not really supported in deconv, just to keep conv2d type ops consistent in parsing
            return [1, 1, 1, 1]
        if self.op in ("Conv2D", "DepthwiseConv2dNative"):
            return [int(rate) for rate in self._info.attr["dilations"].list.i]
        raise UnsupportedModelError(f"Dilations cannot be found in node {self.name} of type {self.op}")

    def get_padding_from_op(self):
        """Translate the node's ``padding`` attribute into a ``PaddingType``.

        Raises:
            UnsupportedPaddingError: for any value other than ``SAME`` or ``VALID``.
        """
        tf_padding = ensure_str(self._info.attr["padding"].s)
        supported = {
            # SAME in tensorflow is different than our chosen SAME padding
            "SAME": PaddingType.same_tensorflow,
            "VALID": PaddingType.valid,
        }
        if tf_padding not in supported:
            raise UnsupportedPaddingError(f"Unsupported TF padding mode used: {tf_padding}.")
        return supported[tf_padding]

    def get_layer_var(self):
        """Find the variable (or const) vertex feeding this layer, and its shape.

        Variable chains are tried first; the shape then comes from the vertex's
        ``shape`` attribute. If none matches, const chains are tried and the shape
        is taken from the const's tensor value.

        Returns:
            tuple: ``(vertex, shape)`` where ``shape`` is a list.

        Raises:
            TFVariableError: when neither a variable nor a const is found.
        """
        variable_chains = [
            [BwdChainNode(op="Identity", name="read"), BwdChainNode(op="Variable")],
            [BwdChainNode(op="ReadVariableOp"), BwdChainNode(op="VarHandleOp")],
            [BwdChainNode(op="ExpandDims"), BwdChainNode(op="ReadVariableOp"), BwdChainNode(op="VarHandleOp")],
        ]
        const_chains = [
            [BwdChainNode(op="Identity"), BwdChainNode(op="Const")],
            [BwdChainNode(op="Const")],
            [BwdChainNode(op="ExpandDims"), BwdChainNode(op="Identity"), BwdChainNode(op="Const")],
        ]

        var_vertex = get_node_from_possible_chains(self._graph, self, variable_chains)
        if var_vertex is not None:
            declared_shape = [int(dim.size) for dim in var_vertex._info.attr["shape"].shape.dim]
            return var_vertex, declared_shape

        const_vertex = get_node_from_possible_chains(self._graph, self, const_chains)
        if const_vertex is None:
            raise TFVariableError(self.name)
        const_value = tf.make_ndarray(const_vertex._info.attr["value"].tensor)
        return const_vertex, list(const_value.shape)

    def get_layer_var_data(self):
        """Return a deep copy of the layer's variable values, and the variable shape.

        Const vertices are decoded from their tensor value. Other vertices are read
        from ``self._graph.values`` by name and each dimension is compared with the
        declared shape (pairwise, over the shorter of the two).

        Raises:
            TFVariableError: when a stored dimension differs from the declared one.
        """
        var_vertex, var_shape = self.get_layer_var()
        if var_vertex.op == "Const":
            values = tf.make_ndarray(var_vertex._info.attr["value"].tensor)
        else:
            values = self._graph.values[var_vertex.name]
            shape_mismatch = any(actual != expected for actual, expected in zip(values.shape, var_shape))
            if shape_mismatch:
                raise TFVariableError(var_vertex.name)
        return copy.deepcopy(values), var_shape

    def _is_var_layer(self):
        """Tell whether this vertex represents a variable or a constant.

        True when the op name contains ``Const`` or ``ReadVariableOp``, or when it
        contains ``Identity`` and either has a ``Const`` predecessor or is a
        ``read`` identity with a ``Variable`` predecessor.
        """
        op_type = self.op
        variable_pred = look_for_node(self._graph, self, [BwdChainNode(op="Variable")])
        const_pred = look_for_node(self._graph, self, [BwdChainNode(op="Const")])

        # TF1.13 Upgrade: read variable can be an atomic op, instead of identity+var
        if "Const" in op_type or "ReadVariableOp" in op_type:
            return True
        if "Identity" not in op_type:
            return False
        reads_variable = "read" in self.name and variable_pred is not None
        return reads_variable or const_pred is not None

    def _is_shape_info_layer(self):
        """Tell whether this vertex only carries shape information.

        ``Pack``, ``StridedSlice`` and ``Shape`` ops always qualify; a ``Mul``
        qualifies when a ``StridedSlice`` <- ``Shape`` chain is found behind it.
        """
        if self.op in ("Pack", "StridedSlice", "Shape"):
            return True
        if self.op != "Mul":
            return False

        shape_chain = [BwdChainNode(op="StridedSlice"), BwdChainNode(op="Shape")]
        return look_for_node(self._graph, self, shape_chain) is not None

    def is_input_op(self):
        """Tell whether this vertex's op is one of ``INPUT_OPS``."""
        op_type = self.op
        return op_type in INPUT_OPS

    def get_iterator_get_next_multiple_output_io_indices(self):
        """List the output index each successor consumes from an ``IteratorGetNext``.

        For every successor, the first input containing this node's name is
        inspected; the index is the integer after the last ``:`` or 0 when there is
        no ``:``. Successors without such an input are skipped. Any other op yields
        an empty list.
        """
        if self.op != "IteratorGetNext":
            return []

        indices = []
        for succ in self._graph.successors(self):
            matching = [inp for inp in succ.input if self.name in inp]
            if not matching:
                continue
            first_match = matching[0]  # expecting first occurrence anyway
            _, separator, suffix = first_match.rpartition(":")
            indices.append(int(suffix) if separator else 0)
        return indices

    def duplicate_iterator_get_next_by_io_indices(self, io_indices):
        """Create a dedicated vertex for every non-zero output index of this node.

        For each index in ``io_indices`` greater than zero, a new ``TFGraphNode``
        is built from this node's proto, named ``<name>:<io_index>``, added to the
        graph and connected to every predecessor of this node. Successor inputs
        that contain this node's name and end with ``:<io_index>`` are rewritten to
        the new vertex name, an edge from the new vertex to that successor is
        added, and the original edge is queued for removal.

        Args:
            io_indices: iterable of integer output indices.

        Returns:
            dict: io_index -> newly created vertex (indices <= 0 are not included).
        """
        new_duplicates = {}
        # Edges are only removed after all indices were handled, so the successors
        # of this node stay intact while they are being iterated.
        edges_to_remove = []
        for io_index in io_indices:
            if io_index > 0:
                # No logger is passed here, so the duplicate falls back to default_logger()
                new_vertex = TFGraphNode(self._info, self._graph)
                new_vertex.name = f"{self.name}:{io_index}"
                self._graph.add_vertex_by_name(new_vertex)

                for pred in self._graph.predecessors(self):
                    self._graph.add_edge(pred, new_vertex)
                for succ in self._graph.successors(self):
                    for i, inp in enumerate(succ.input):
                        # substring match on the name plus an exact match on the trailing index
                        if ":" in inp and self.name in inp and int(inp.split(":")[-1]) == io_index:
                            succ.input[i] = new_vertex.name
                            self._graph.add_edge(new_vertex, succ)
                            edges_to_remove.append([self, succ])

                new_duplicates[io_index] = new_vertex

        # NOTE(review): the same (self, succ) pair can be queued more than once when a
        # successor consumes several non-zero outputs - confirm remove_edge tolerates that.
        for edge in edges_to_remove:
            self._graph.remove_edge(edge[0], edge[1])

        return new_duplicates

    def get_output_shapes(self, possible_successors_names=None, **kwrags):
        """Return the output shapes of this vertex as a list of int lists.

        Args:
            possible_successors_names: optional names; when given, only the outputs
                consumed by the actual successors with these names are reported.

        The output indices are taken from the inputs of the selected successors
        (``<name>`` means index 0, ``<name>:<i>`` means index ``i``); with no
        selected successors only index 0 is reported. Shapes come from the
        ``_output_shapes`` attribute, falling back to ``element_shape`` (prefixed
        with -1) when no shape entries were found. An ``IteratorGetNext`` without
        selected successors only considers successors whose op is in
        ``GRAPH_START_OPS`` and returns ``[]`` when there are none. For NCHW graphs,
        4-entry shapes are reordered as ``[0, 2, 3, 1]``.
        """
        successors = list(self._graph.successors(self))
        selected = None
        if possible_successors_names:
            selected = [
                self._graph.vertices_by_name[succ.name]
                for succ in successors
                if succ.name in possible_successors_names
            ]

        if self.op in ["IteratorGetNext"] and not selected:
            # Edge case: input generator may have many different types of structs in output, anything that can be found
            # in a tf record. Our approach here is to only allow it if the iterator feeds a "natural" graph start op.
            selected = [succ for succ in successors if succ.op in GRAPH_START_OPS]
            if not selected:
                return []

        if selected:
            indexed_prefix = f"{self.name}:"
            io_indices = []
            for succ in selected:
                for inp in succ._info.input:
                    if inp == self.name:
                        io_indices.append(0)
                    elif indexed_prefix in inp:
                        io_indices.append(int(inp.split(":")[-1]))
        else:
            io_indices = [0]

        declared_shapes = self._info.attr["_output_shapes"].list.shape
        output_shapes = [
            [int(dim.size) for dim in declared_shapes[io_index].dim]
            for io_index in io_indices
            if len(declared_shapes) > io_index
        ]

        if not any(output_shapes) and "element_shape" in self._info.attr:
            output_shapes = [[-1] + [dim.size for dim in self._info.attr["element_shape"].shape.dim]]

        if self.graph.is_nchw:
            reordered = []
            for shape in output_shapes:
                if len(shape) == 4:
                    reordered.append([shape[0], shape[2], shape[3], shape[1]])
                else:
                    reordered.append(shape)
            output_shapes = reordered

        return output_shapes

    def get_feature_split_info(self):
        """Validate a features split node and describe the splits that are used.

        Returns:
            tuple: ``(features_split_dims, split_indices, output_shapes)``, each with
            one entry per used output index: the split size, the offset at which
            the split starts, and the input shape with its last entry replaced by
            the split size.

        Raises:
            UnsupportedModelError: when the op is not in ``SPLIT_OPS``.
            UnsupportedFeatureSplitterError: when the split axis is not the last
                dimension of a rank-2 or rank-4 input, when the split sizes do not
                sum up to the last input dimension, or when more than one but not
                all of the splits are used.
        """
        if self.op not in SPLIT_OPS:
            raise UnsupportedModelError(f"Unexpected feature split node {self.name}")

        # verify that split axis is the features dimension
        data_preds = [pred for pred in list(self._graph.predecessors(self)) if pred.op != "Const"]
        data_shapes = [pred.get_output_shapes(possible_successors_names=[self.name])[0] for pred in data_preds]
        input_shape = data_shapes[0]
        axis_node = look_and_validate(self._graph, self, [BwdChainNode(op="Const", name="split_dim")])
        axis = axis_node._info.attr["value"].tensor.int_val[0]
        last_axis_by_rank = {2: (-1, 1), 4: (-1, 3)}
        if axis not in last_axis_by_rank.get(len(input_shape), ()):
            raise UnsupportedFeatureSplitterError(
                f"Features split is currently supported along last dimension "
                f"(features) only. Node {self.name} has axis {axis}.",
            )

        # verify that the sum of split sizes amounts to the total features in the output shape
        num_splits = self._info.attr["num_split"].i
        sizes_node = look_for_node(self._graph, self, [BwdChainNode(op="Const", name="Const")])
        if sizes_node is None:
            split_sizes = [int(input_shape[-1] / num_splits) for _ in range(num_splits)]
        else:
            split_sizes = [int(size) for size in tf.make_ndarray(sizes_node._info.attr["value"].tensor)]
        if np.sum(split_sizes) != input_shape[-1]:
            raise UnsupportedFeatureSplitterError(
                f"Feature split node {self.name} must have an output shape where the features "
                "dimension is the sum of all splits sizes. "
                f"Found num_splits={num_splits} and split_sizes={split_sizes}.",
            )

        # used output indices, ordered by the name of the first successor consuming each of them
        by_first_consumer = sorted(
            self.get_vertex_successors_io_indices().items(),
            key=lambda item: item[1][0].name,
        )
        io_indices = [io_index for io_index, _ in by_first_consumer]

        # TODO: Add support of using more than one output (and not all) of split layer - SDK-30108
        # More than 1 split is used, but not all splits
        if len(io_indices) > 1 and len(io_indices) != num_splits:
            raise UnsupportedFeatureSplitterError(
                f"Feature split node {self.name} must have an output for each split, "
                f"or no more than 1 output. Found num_splits={num_splits} and "
                f"num_outputs={len(io_indices)}.",
            )

        # index of the beginning of the split for each output
        split_indices = [sum(split_sizes[:io_index]) for io_index in io_indices]
        features_split_dims = [split_sizes[io_index] for io_index in io_indices]
        output_shapes = [input_shape[:-1] + [split_sizes[io_index]] for io_index in io_indices]
        return features_split_dims, split_indices, output_shapes

    def get_vertex_successors_io_indices(self):
        """Group the successors of this vertex by the output index they consume.

        A successor input equal to this node's name counts as index 0; an input
        containing ``<name>:`` uses the integer after its last ``:``. A successor
        appears once per matching input.

        Returns:
            dict: io_index -> list of successor vertices.
        """
        indexed_prefix = f"{self.name}:"
        successors_by_index = {}
        for succ in self._graph.successors(self):
            for inp in succ._info.input:
                if inp != self.name and indexed_prefix not in inp:
                    continue
                _, separator, suffix = inp.rpartition(":")
                io_index = int(suffix) if separator else 0
                successors_by_index.setdefault(io_index, []).append(succ)
        return successors_by_index

    def get_dilated_s2b_padding(self, dilations):
        """Derive the padding type of a dilated conv from its ``SpaceToBatchND`` node.

        Args:
            dilations: sequence whose entries 1 and 2 are the height and width rates.

        Returns:
            ``PaddingType.valid`` when the first padding of both spatial rows is 0,
            ``PaddingType.same_tensorflow`` when they equal the dilation rates.

        Raises:
            UnsupportedModelError: when this node is not a ``SpaceToBatchND``.
            UnsupportedDilationError: when a following conv does not use ``VALID``
                padding, or when the paddings match neither supported case.
        """
        if self.op not in ["SpaceToBatchND"]:
            raise UnsupportedModelError(f"Unsupported padding in dilated Conv2D node {self.name}")

        # verify that the middle layer in the s2b->b2s block has valid padding
        for conv_op in ("Conv2D", "DepthwiseConv2dNative"):
            conv = look_for_node(self._graph, self, [FwdChainNode(op=conv_op)])
            if conv is None:
                continue
            if ensure_str(conv._info.attr["padding"].s) != "VALID":
                raise UnsupportedDilationError(f"Unsupported padding in dilated Conv2D node {conv.name}")

        paddings_const = look_and_validate(self._graph, self, [BwdChainNode(op="Const", name="paddings")])
        paddings = tf.make_ndarray(paddings_const._info.attr["value"].tensor)
        dilation_h, dilation_w = dilations[1:3]
        pad_h, pad_w = paddings[0][0], paddings[1][0]
        if pad_h == 0 and pad_w == 0:
            return PaddingType.valid
        if dilation_h == pad_h and dilation_w == pad_w:
            return PaddingType.same_tensorflow
        raise UnsupportedDilationError(f"Unsupported padding in dilated Conv2D node {self.name}")

    def get_padding_from_const(self):
        """Read the explicit paddings of a ``Pad``/``PadV2`` node.

        The paddings come from a ``paddings`` const predecessor, or from a
        ``Reshape`` predecessor fed by a ``stack`` const (values) and a ``shape``
        const (target shape).

        Returns:
            list: rows 1..3 of the (4, 2) paddings array, cast to int and flattened.

        Raises:
            UnsupportedModelError: when this node is not a pad op.
            UnsupportedPaddingError: when the paddings are not of shape (4, 2) with
                a zero first row, when the fill value is neither 0 nor -inf, or
                when a -inf fill value is used right before a ``Conv2D``.
        """
        if self.op not in ["Pad", "PadV2"]:
            raise UnsupportedModelError(f"Unexpected padding node {self.name}")

        # paddings const in regular pad op comes in a constant before the pad op itself
        paddings_const = look_for_node(self._graph, self, [BwdChainNode(op="Const", name="paddings")])
        if paddings_const is None:
            # In resize_image_with_crop_or_pad, before the pad op there is a reshape, and before that 2 constants,
            # one holds the padding values and one holds its shape
            stack_const = look_and_validate(
                self._graph,
                self,
                [BwdChainNode(op="Reshape", name="Reshape"), BwdChainNode(op="Const", name="stack")],
            )
            shape_const = look_and_validate(
                self._graph,
                self,
                [BwdChainNode(op="Reshape", name="Reshape"), BwdChainNode(op="Const", name="shape")],
            )
            flat_paddings = tf.make_ndarray(stack_const._info.attr["value"].tensor)
            target_shape = tf.make_ndarray(shape_const._info.attr["value"].tensor)
            padding_value = np.reshape(flat_paddings, target_shape)
        else:
            padding_value = tf.make_ndarray(paddings_const._info.attr["value"].tensor)

        # TODO: improve validation using current layer input shape
        expected_layout = padding_value.shape == (4, 2) and np.array_equal(padding_value[0, :], np.array([0, 0]))
        if not expected_layout:
            raise UnsupportedPaddingError(f"Unexpected padding value {padding_value} in node {self.name}")

        # handle padding const value: only accepted values are 0, -inf
        fill_const = look_for_node(self._graph, self, [BwdChainNode(op="Const", name="constant_values")])
        if fill_const is None:
            const_value = 0
        else:
            const_value = tf.make_ndarray(fill_const._info.attr["value"].tensor).tolist()

        if const_value not in (0, -np.inf):
            raise UnsupportedPaddingError(f"Unsupported padding const_values {const_value} given in node {self.name}")
        if const_value == -np.inf and self.is_before_conv(["Conv2D"]):
            raise UnsupportedPaddingError(
                f"Unsupported {const_value} padding const_values before Conv2D layer at node {self.name}",
            )
        if const_value == 0 and not self.is_after_relu():
            self._logger.debug("Padding with const value of 0, in a node NOT after Relu is probably not cool.")

        return padding_value[1:].astype(int).flatten().tolist()

    def is_keras_relu6(self):
        """Tell whether this node is followed by the Keras relu6 clipping pattern.

        The pattern is a ``Minimum`` (optionally followed by a ``Maximum``) whose
        const operand is 6.0; the ``Maximum`` const, when present, must be 0.0.
        The ``Minimum`` const may be reached directly or through a ``Cast``.

        Returns:
            bool: True when the pattern matches. False when there is no
            ``Minimum`` after this node, or when its const operand cannot be found
            (this previously raised ``AttributeError``).
        """
        possible_chains = [[FwdChainNode(op="Minimum"), FwdChainNode(op="Maximum")], [FwdChainNode(op="Minimum")]]
        node = get_node_from_possible_chains(self._graph, self, possible_chains)
        if node is None:
            return False

        min_node_possible_chains = [
            [FwdChainNode(op="Minimum"), BwdChainNode(op="Const")],
            [FwdChainNode(op="Minimum"), BwdChainNode(op="Cast"), BwdChainNode(op="Const")],
        ]
        min_const_node = get_node_from_possible_chains(self._graph, self, min_node_possible_chains)
        if min_const_node is None:
            # Minimum against a non-const operand cannot be a relu6 clip
            return False

        max_const_node = look_for_node(
            self._graph,
            self,
            [FwdChainNode(op="Minimum"), FwdChainNode(op="Maximum"), BwdChainNode(op="Const")],
        )

        min_const = min_const_node.get_const_val()
        # a missing Maximum const is treated as a lower bound of 0.0
        max_const = max_const_node.get_const_val() if max_const_node is not None else 0.0
        return min_const == 6.0 and max_const == 0.0

    def is_shuffle(self):
        """Tell whether this reshape starts a Reshape -> Transpose -> Reshape shuffle.

        Requires the first output shape to have 5 entries and a ``Transpose``
        followed by a ``Reshape`` after this node.
        """
        if len(self.get_output_shapes()[0]) != 5:
            return False
        shuffle_tail = [FwdChainNode(op="Transpose"), FwdChainNode(op="Reshape")]
        closing_reshape = look_for_node(self._graph, self, shuffle_tail)
        return closing_reshape is not None

    def get_asymmetric_depth_to_space_params(self):
        """Return the block sizes and output shapes of an asymmetric depth-to-space.

        The output shapes are those of the ``Reshape`` that follows a ``Transpose``
        after this node, or of this node itself for the single-reshape variant.
        The block sizes are the height and width ratios between the first output
        shape and the first input shape that has 4 entries.

        Returns:
            tuple: ``([block_h, block_w], output_shapes)``.
        """
        closing_reshape = look_for_node(
            self._graph,
            self,
            [FwdChainNode(op="Transpose"), FwdChainNode(op="Reshape")],
        )
        if closing_reshape is None and self.is_asymmetric_depth_to_space_single_reshape():
            closing_reshape = self

        input_shape = next((shape for shape in self.get_input_shapes() if len(shape) == 4), None)

        output_shapes = closing_reshape.get_output_shapes()
        out_shape = output_shapes[0]
        block_h = out_shape[1] // input_shape[1]
        block_w = out_shape[2] // input_shape[2]
        return [block_h, block_w], output_shapes

    def is_asymmetric_depth_to_space_single_reshape(self):
        """Tell whether this single reshape is equivalent to an asymmetric depth-to-space.

        Both the first input shape and the first output shape must have 4 entries,
        the batch must be unchanged and the features halved.
        """
        output_shape = self.get_output_shapes()[0]
        input_shape = self.get_input_shapes()[0]
        if len(input_shape) != 4 or len(output_shape) != 4:
            return False

        in_batch, in_height, in_width, in_features = input_shape
        out_batch, out_height, out_width, out_features = output_shape

        # support D2S variant with single Reshape vertex, where:
        #   1. (1,H,W,F)->(1,H,(2*W),F/2) is equivalent to D2S with (1,2) block size
        #   2. (1,H,1,F)->(1,(2*H),1,F/2) is equivalent to D2S with (2,1) block size, and W!=1 isn't supported
        if in_batch != out_batch or in_features != out_features * 2:
            return False
        if in_width * 2 == out_width:
            return True
        return in_height * 2 == out_height and in_width == out_width == 1

    def is_asymmetric_depth_to_space(self):
        """Tell whether this reshape starts an asymmetric depth-to-space block.

        Two variants are accepted: a single reshape (delegated to
        ``is_asymmetric_depth_to_space_single_reshape``) and a
        Reshape -> Transpose([0, 2, 1, 3]) -> Reshape chain whose shapes are
        consistent with the block sizes derived from the 4-entry input shape and
        the final output shape.

        Returns:
            bool: True when the pattern matches. Malformed candidates (missing
            permutation const, shapes that are too short, zero-sized spatial
            dimensions or block sizes) return False instead of raising.

        Raises:
            UnsupportedModelError: when this node is not a ``Reshape``.
        """
        if self.op not in ["Reshape"]:
            raise UnsupportedModelError(f"Unexpected reshape node {self.name}")

        transpose_node = look_for_node(self._graph, self, [FwdChainNode(op="Transpose")])
        reshape_node = look_for_node(self._graph, self, [FwdChainNode(op="Transpose"), FwdChainNode(op="Reshape")])

        if reshape_node is None and transpose_node is None:
            return self.is_asymmetric_depth_to_space_single_reshape()
        if reshape_node is None or transpose_node is None:
            return False

        perm_node = look_for_node(self._graph, transpose_node, [BwdChainNode(op="Const")])
        if perm_node is None:
            # permutation is not a constant, so the pattern cannot be verified
            return False
        transpose_permutation = list(tf.make_ndarray(perm_node._info.attr["value"].tensor))
        if transpose_permutation != [0, 2, 1, 3]:
            return False

        input_shape = next((shape for shape in self.get_input_shapes() if len(shape) == 4), None)
        if not input_shape:
            return False

        output_shape = reshape_node.get_output_shapes()[0]
        first_reshape_output_shape = self.get_output_shapes()[0]
        # the comparisons below index entries 0..3 of both shapes
        if len(output_shape) != 4 or len(first_reshape_output_shape) < 4:
            return False

        # zero-sized spatial dims or block sizes would divide by zero below
        if input_shape[1] == 0 or input_shape[2] == 0:
            return False
        block_sizes = [output_shape[1] // input_shape[1], output_shape[2] // input_shape[2]]
        block_area = block_sizes[0] * block_sizes[1]
        if block_area == 0 or input_shape[3] % block_area != 0:
            return False

        if input_shape[0] <= 0:
            self._logger.debug(
                f"DepthToSpace layer near {self.name} has incomplete shapes validation due to unknown batch size",
            )

        if (
            (first_reshape_output_shape[0] != input_shape[0] * input_shape[1] and input_shape[0] > 0)
            or (first_reshape_output_shape[1] != input_shape[2])
            or (first_reshape_output_shape[2] != block_sizes[0])
            or (first_reshape_output_shape[3] != old_div(input_shape[3], block_sizes[0]))
        ):
            return False

        expected_output_shape = [
            input_shape[0],
            block_sizes[0] * input_shape[1],
            block_sizes[1] * input_shape[2],
            old_div(input_shape[3], block_area),
        ]

        return expected_output_shape == output_shape

    def has_drop_rate(self):
        """Tell whether an ``If`` node named ``drop_rate`` is found behind this node."""
        drop_rate_node = look_for_node(self._graph, self, [BwdChainNode(op="If", name="drop_rate")])
        return drop_rate_node is not None

    def is_after_relu(self):
        """Tell whether a ``Relu`` node is found directly behind this node."""
        relu_node = look_for_node(self._graph, self, [BwdChainNode(op="Relu")])
        return relu_node is not None

    def is_before_slice_layer(self):
        """Tell whether a ``Slice`` node is found directly after this node."""
        slice_node = look_for_node(self._graph, self, [FwdChainNode(op="Slice")])
        return slice_node is not None

    def is_nchw(self):
        """Tell whether the node has a ``data_format`` attribute equal to ``NCHW``."""
        attrs = self._info.attr
        if "data_format" not in attrs:
            return False
        return ensure_str(attrs["data_format"].s) == "NCHW"

    def is_bn_in_training(self):
        """Return the node's ``is_training`` attribute, or False when it is absent."""
        attrs = self._info.attr
        return attrs["is_training"].b if "is_training" in attrs else False

    def is_resize_nearest_reshape(self):
        """Tell whether this reshape starts a Reshape -> Mul -> Reshape nearest-neighbor resize.

        The pattern requires:

        * a ``Mul`` after this node, with a ``Const`` operand and a following ``Reshape``;
        * this node's first output shape to have 6 entries with entries 2 and 4 equal to 1;
        * the ``Mul`` const to be all ones, of rank 6, with dims 0, 1, 3 and 5 equal to 1;
        * the second reshape's output to have 4 entries whose entries 1 and 2 equal
          this node's entries 1 and 3 multiplied by the const's dims 2 and 4.

        The checks are evaluated in order and stop at the first failure, so a
        low-rank ``Mul`` const or a short shape yields False instead of raising
        ``IndexError``.

        Raises:
            UnsupportedModelError: when this node is not a ``Reshape``.
        """
        if self.op not in ["Reshape"]:
            raise UnsupportedModelError(f"Unexpected reshape node {self.name}")

        mul = look_for_node(self._graph, self, [FwdChainNode(op="Mul")])
        if not mul:
            return False

        mul_const = look_for_node(self._graph, mul, [BwdChainNode(op="Const")])
        if not mul_const:
            return False

        reshape = look_for_node(self._graph, mul, [FwdChainNode(op="Reshape")])
        if not reshape:
            return False

        first_shape = self.get_output_shapes()[0]
        mul_vals = tf.make_ndarray(mul_const._info.attr["value"].tensor)
        mul_shape = mul_vals.shape
        second_shape = reshape.get_output_shapes()[0]

        # supporting only rank 4, with the multiplication structure that implements resize NN
        if len(first_shape) != 6 or first_shape[2] != 1 or first_shape[4] != 1:
            return False
        if not np.all(mul_vals == 1):
            return False
        if len(mul_shape) != 6 or any(mul_shape[axis] != 1 for axis in (0, 1, 3, 5)):
            return False

        # both shapes are known to be rank 6 here, so the indexing below is safe
        return (
            len(second_shape) == 4
            and first_shape[1] * mul_shape[2] == second_shape[1]
            and first_shape[3] * mul_shape[4] == second_shape[2]
        )

    def is_resize_nearest_mul(self):
        """Tell whether this ``Mul`` is the middle of a nearest-neighbor resize block.

        True when this node is a ``Mul`` and any of its predecessors is a
        ``Reshape`` for which ``is_resize_nearest_reshape()`` holds. All
        predecessors are inspected, so the result does not depend on their order,
        and a ``Mul`` without predecessors yields False.
        """
        if self.op != "Mul":
            return False
        return any(
            pred.op == "Reshape" and pred.is_resize_nearest_reshape() for pred in self._graph.predecessors(self)
        )

    def get_reshape_as_resize_nearest_info(self):
        """Return the upscale factors and output shapes of a reshape-based resize.

        The factors are dims 2 and 4 of the const operand of the ``Mul`` after this
        node; the output shapes are those of the ``Reshape`` after that ``Mul``.

        Returns:
            tuple: ``([factor_a, factor_b], output_shapes)``.
        """
        mul, closing_reshape = get_all_nodes_in_chain(
            self._graph,
            self,
            [FwdChainNode(op="Mul"), FwdChainNode(op="Reshape")],
        )
        mul_const = look_and_validate(self._graph, mul, [BwdChainNode(op="Const")])
        const_shape = tf.make_ndarray(mul_const._info.attr["value"].tensor).shape
        factors = [const_shape[2], const_shape[4]]
        return factors, closing_reshape.get_output_shapes()

    def validate_global_avg_pool(self):
        """Verify that this ``Mean`` node is equivalent to a global average pool.

        The reduction axes (read from a ``Const`` predecessor) must be ``[1, 2]``,
        or ``[2]`` with an input height of 1, or ``[1]`` with an input width of 1.

        Raises:
            UnsupportedModelError: when this node is not a ``Mean``.
            UnsupportedReduceMeanLayerError: when the axes are not supported; the
                message also mentions keep_dims when the input and output ranks differ.
        """
        if self.op != "Mean":
            raise UnsupportedModelError(f"Unexpected global average pool node {self.name}")

        input_shape = self.get_input_shapes()[0]
        output_shape = self.get_output_shapes()[0]

        axes_const = look_and_validate(self._graph, self, [BwdChainNode(op="Const")])
        axis = tf.make_ndarray(axes_const._info.attr["value"].tensor)

        # axes are just on width, but height is 1
        width_only_ok = np.array_equal(axis, np.array([2])) and input_shape[1] == 1
        # axes are just on height, but width is 1
        height_only_ok = np.array_equal(axis, np.array([1])) and input_shape[2] == 1
        both_spatial = np.array_equal(axis, np.array([1, 2]))

        # keep_dims=False is ok if the correct axis are used
        rank_changed = len(input_shape) != len(output_shape)
        if both_spatial or width_only_ok or height_only_ok:
            return

        complaints = f"unsupported axis {axis} (must be over spatial dimensions only), "
        if rank_changed:
            complaints += "and unsupported keep_dims=False, "
        raise UnsupportedReduceMeanLayerError(
            f"Reduce mean layer {self.name} has {complaints}must be equivalent to global average pool.",
        )

    def is_global_max_pool(self):
        """Return True when this ``Max`` node reduces exactly over axes ``[1, 2]``.

        The axes are read from the ``Const`` predecessor whose name matches
        ``reduction_indices``. Any op other than ``Max`` yields False.
        """
        if self.op != "Max":
            return False

        axes_chain = [BwdChainNode(op="Const", name="reduction_indices")]
        axes_node = look_and_validate(self._graph, self, axes_chain)
        reduced_axes = tf.make_ndarray(axes_node._info.attr["value"].tensor)
        return np.array_equal(reduced_axes, np.array([1, 2]))

    def is_before_conv(self, conv_ops):
        """Return True if a node with any op in ``conv_ops`` is found one step forward.

        Args:
            conv_ops: iterable of op names to look for.
        """
        # TODO: add deconv and depthwise conv support?
        # TODO: what about add that has two output nodes?
        for conv_op in conv_ops:
            if look_for_node(self._graph, self, [FwdChainNode(op=conv_op)]) is not None:
                return True
        return False

    def is_before_reshape_to_dense(self):
        """Return True if this node is followed by ``Reshape`` [-> ``Identity``] -> ``MatMul``."""
        direct_chain = [FwdChainNode(op="Reshape"), FwdChainNode(op="MatMul")]
        via_identity_chain = [FwdChainNode(op="Reshape"), FwdChainNode(op="Identity"), FwdChainNode(op="MatMul")]
        matched = get_node_from_possible_chains(self._graph, self, [direct_chain, via_identity_chain])
        return matched is not None

    def is_null_transpose(self):
        """Return True for a ``Transpose`` whose output has at most one non-unit dim after dim 0.

        Counts the dimensions equal to 1 among ``output_shape[1:]`` and requires
        that count to be at least ``rank - 2``.
        """
        out_shape = self.get_output_shapes()[0]
        unit_dims = sum(1 for dim in out_shape[1:] if dim == 1)
        if self.op != "Transpose":
            return False
        return unit_dims >= len(out_shape) - 2

    def is_features_reshape(self):
        """Return True when a rank-4 input with dim 2 equal to 1 has its last dim split in two.

        Both shapes must have rank 4, share their first two dims, and the input's
        last dim must equal the product of the output's last two dims.
        """
        out_shape = self.get_output_shapes()[0]
        in_shape = self.get_input_shapes()[0]
        if len(in_shape) != 4 or len(out_shape) != 4:
            return False

        same_leading_dims = in_shape[0:2] == out_shape[0:2]
        return same_leading_dims and in_shape[2] == 1 and in_shape[3] == out_shape[2] * out_shape[3]

    def is_width_features_transpose(self):
        """Return True when this ``Transpose`` uses the permutation ``[0, 1, 3, 2]``.

        The permutation is read from a ``Const`` predecessor. If the op is not
        ``Transpose`` or no ``Const`` predecessor is found, False is returned.
        """
        if self.op != "Transpose":
            return False

        perm_node = look_for_node(self._graph, self, [BwdChainNode(op="Const")])
        if perm_node is None:
            # Without a constant permutation the pattern cannot be confirmed;
            # previously this case crashed with AttributeError on None.
            return False
        perm_tensor = perm_node._info.attr["value"].tensor
        transpose_permutation = list(tf.make_ndarray(perm_tensor))
        return transpose_permutation == [0, 1, 3, 2]

    def is_flat_to_frames_reshape(self):
        """Return True when a rank-2 input is reshaped into rank-3/rank-4 "frames".

        For a rank-3 output the product of dims 1 and 2 must equal input dim 1.
        For a rank-4 output the product of dims 2 and 3 must equal input dim 1
        and output dim 1 must be 1. Every other combination yields False.
        """
        out_shape = self.get_output_shapes()[0]
        in_shape = self.get_input_shapes()[0]
        if len(in_shape) != 2:
            return False

        out_rank = len(out_shape)
        if out_rank == 3:
            return out_shape[1] * out_shape[2] == in_shape[1]
        if out_rank == 4:
            return (out_shape[2] * out_shape[3] == in_shape[1]) and out_shape[1] == 1
        return False

    def has_batchnorm_in_vertex_name(self, liberal_naming=False):
        """Return True if the node name looks like it belongs to a batch-norm scope.

        The lower-cased node name is matched against a regex built for each
        entry of ``BATCH_NORM_NAMES``.

        Args:
            liberal_naming: when truthy, a name that simply starts with ``"bn"``
                (checked on the name as-is, not lower-cased) is accepted too.
        """
        # Accepted patterns:
        # 1. A sub scope name begins with a batch norm name.
        # 2. A sub scope name ends with a batch norm name.
        # 3. A sub scope is named "bn".
        # 4. A sub scope contains the word batchnorm or batch_norm.
        # 5. A sub scope contains the combo '-bn-'.
        # 6. A sub scope name ends with a delimited "bn" (_/- prefix, possible digits suffix).
        bn_pattern = (
            r"((\S*\/)*{}\S*\Z)|"
            r"((\S*\/)*\S*{}(\/\S*)*\Z)|"
            r"((\S*\/)+bn(\/\S*)*)|"
            r"((\S*\/)*\S*batch_*norm\S*(\/\S*)*\Z)|"
            r"((\S*\/)*\S*-bn-\S*(\/\S*)*\Z)|"
            r"((\S*\/)*\S*([_-]+)bn([0-9]*)(\/\S*)*\Z)"
        )

        for bn_name in BATCH_NORM_NAMES:
            if re.match(bn_pattern.format(bn_name, bn_name), self.name.lower()):
                return True

        if liberal_naming:
            return self.name.startswith("bn")
        return False

    def is_conv_batchnorm(self):
        """Return True for an add op fed by a ``Const`` and a batch-norm-named ``Conv2D``."""
        # In frozen graphs a batchnorm op shows up as a "Conv2D" op followed by an "Add" op
        if self.op not in ADD_OPS:
            return False

        conv_pred = look_for_node(self._graph, self, [BwdChainNode(op="Conv2D")])
        const_pred = look_for_node(self._graph, self, [BwdChainNode(op="Const")])
        if const_pred is None or conv_pred is None:
            return False
        return conv_pred.has_batchnorm_in_vertex_name()

    def get_bn_info(self):
        """Build the ``BatchNormValues`` describing this node.

        For a ``Mul`` node followed by an ``Add``, where both have a ``Const``
        input, the constants are used directly as gamma and beta with zero mean,
        unit variance and zero epsilon. Otherwise the values come from
        ``BNParser``.
        """
        if self.op == "Mul":
            following_add = look_for_node(self._graph, self, [FwdChainNode(op="Add")])
            gamma_const = look_for_node(self._graph, self, [BwdChainNode(op="Const")])
            beta_const = look_for_node(self._graph, self, [FwdChainNode(op="Add"), BwdChainNode(op="Const")])
            if following_add and gamma_const and beta_const:
                zero_mean = np.array(0.0, dtype=np.float32)
                unit_variance = np.array(1.0, dtype=np.float32)
                beta = tf.make_ndarray(beta_const._info.attr["value"].tensor)
                gamma = tf.make_ndarray(gamma_const._info.attr["value"].tensor)
                zero_epsilon = np.array(0.0, dtype=np.float32)
                return BatchNormValues(
                    moving_mean=zero_mean,
                    moving_variance=unit_variance,
                    beta=beta,
                    gamma=gamma,
                    epsilon=zero_epsilon,
                )

        parser = BNParser(self)
        moving_mean = parser.get_bn_moving_mean()
        moving_variance = parser.get_bn_moving_variance()
        beta = parser.get_bn_beta()
        gamma = parser.get_bn_gamma()
        epsilon = parser.get_bn_epsilon()
        return BatchNormValues(
            moving_mean=moving_mean,
            moving_variance=moving_variance,
            beta=beta,
            gamma=gamma,
            epsilon=epsilon,
        )

    def get_input_shapes(self):
        """Return the first output shape of every known input node.

        Input names are stripped of their ``:<suffix>`` part before lookup;
        names that are not present in ``graph.vertices_by_name`` are skipped.
        Returns an empty list when ``self.input`` is None.
        """
        producer_names = self.input
        if producer_names is None:
            return []

        shapes = []
        for tensor_name in producer_names:
            if tensor_name.split(":")[0] not in self.graph.vertices_by_name:
                continue
            producer = self._graph.get_vertex_by_name(tensor_name.split(":")[0])
            # the producer is told which successor is asking for its shapes
            shapes.append(producer.get_output_shapes(possible_successors_names=[self.name])[0])
        return shapes

    def get_depth_to_space_block_size(self):
        """Return the ``block_size`` attribute of a ``DepthToSpace`` node.

        Raises:
            UnexpectedNodeError: if the node's op is not ``DepthToSpace``.
        """
        if self.op == "DepthToSpace":
            return self._info.attr["block_size"].i
        raise UnexpectedNodeError(
            f"Block size is only supported in DepthToSpace ops, {self.name} is not compatible.",
        )

    def get_space_to_depth_block_size(self):
        """Return the integer ``block_size`` attribute of this node (op is not checked)."""
        block_size_attr = self._info.attr["block_size"]
        return block_size_attr.i

    def get_inputs(self):
        """Return this node's ``input`` attribute unchanged."""
        node_inputs = self.input
        return node_inputs

    def decide_which_add(self):
        """Classify this add node as elementwise add, bias add or normalization.

        Returns:
            ``AddOpRole.ew`` when both inputs have the same rank (2 or 4) and
            their shapes match the output, one of them has unknown spatial dims,
            or they are spatially broadcastable; ``AddOpRole.bias`` when one
            input is a variable read / const whose shape matches the last dim of
            the other input; ``AddOpRole.normalization`` when that input is a
            scalar or its first dim is 1.

        Raises:
            UnexpectedNodeError: if the node's op is not in ``ADD_OPS``.
            UnsupportedAddLayerError: if the node does not have exactly two
                inputs, no variable/const input is found for the bias case, or
                the bias shapes do not match.
        """
        if self.op not in ADD_OPS:
            raise UnexpectedNodeError(f"Tried to parse node {self.name} as an Add node, but its op is {self.op}")

        preds = list(self._graph.predecessors(self))
        # in case of ew_add(x, x) we create normalization of 2*x and ignore this func return value
        if self.is_mul_by_2_ew_add():
            # duplicate the predecessor list so the two-input checks below still apply
            preds = preds * 2
        preds_outs = [pred.get_output_shapes(possible_successors_names=[self.name])[0] for pred in preds]
        if len(preds) != 2:
            raise UnsupportedAddLayerError("Add node %s has %d != 2 inputs" % (self.name, len(preds)))

        # try I - is it elementwise add? check if preds have same shape as my shape
        ranks = [len(x) for x in preds_outs]
        out_shapes = self.get_output_shapes()[0]
        # chained comparison: both ranks are equal AND that rank is 2 or 4
        ranks_equal = ranks[0] == ranks[1] in [2, 4]
        shapes_equal = preds_outs[0] == preds_outs[1] == out_shapes
        # index of the last (features) dim: 3 for rank 4, otherwise 1
        end_index = 3 if ranks[0] == 4 else 1
        # NOTE(review): with end_index == 1 the [1:end_index] slice is empty, so the
        # unknown-shape checks below look like they can only match rank-4 inputs - confirm
        first_shape_unknown = (
            preds_outs[0][1:end_index] == [-1, -1]
            and preds_outs[1] == out_shapes
            and preds_outs[0][end_index] == preds_outs[1][end_index]
        )
        second_shape_unknown = (
            preds_outs[1][1:end_index] == [-1, -1]
            and preds_outs[0] == out_shapes
            and preds_outs[0][end_index] == preds_outs[1][end_index]
        )

        broadcast = ranks_equal and is_spatial_broadcast(preds_outs[0], preds_outs[1], is_two_sided=True)
        if ranks_equal and (shapes_equal or first_shape_unknown or second_shape_unknown or broadcast):
            if first_shape_unknown or second_shape_unknown:
                self._logger.debug(
                    f"EW add layer {self.name} is allowing an input with an unknown shape (width/height), "
                    "and is assuming it matches the other input shape",
                )
            return AddOpRole.ew

        # try II - is it bias add? check if it reads from a variable and shapes make sense
        var_ops_chains = [[BwdChainNode(op="Identity", name="read")], [BwdChainNode(op="ReadVariableOp")]]
        const_ops_chain = [BwdChainNode(op="Const")]
        bias_node = get_node_from_possible_chains(self._graph, self, var_ops_chains)
        if bias_node is None:
            # no variable read found - fall back to a constant input
            bias_node = look_for_node(self._graph, self, const_ops_chain)
            if bias_node is None:
                raise UnsupportedAddLayerError(
                    f"Tried to parse Add node {self.name} as bias, but its input is not a "
                    f"variable read node or const",
                )

        # order the two shapes as (bias, data) depending on which predecessor is the bias node
        bias_shape, real_input_shape = preds_outs if bias_node == preds[0] else preds_outs[::-1]
        # bias: shape [C] or [1, C] where C equals the last dim of the data input
        if (len(bias_shape) == 1 and bias_shape[0] == real_input_shape[-1]) or (
            len(bias_shape) == 2 and bias_shape[0] == 1 and bias_shape[1] == real_input_shape[-1]
        ):
            return AddOpRole.bias
        # scalar (empty shape) or leading dim of 1 is treated as normalization
        elif not bias_shape or bias_shape[0] == 1:
            return AddOpRole.normalization

        raise UnsupportedAddLayerError(f"Tried to parse Add node {self.name} as bias, but the shapes don't match")

    def _get_leaky_alpha_from_max(self):
        """Look backwards through a ``Mul`` for a ``Const`` named ``alpha``; returns the lookup result."""
        alpha_chain = [BwdChainNode(op="Mul"), BwdChainNode(op="Const", name="alpha")]
        alpha_node = look_for_node(self._graph, self, alpha_chain)
        return alpha_node

    def is_leaky_max(self):
        """Return True when ``_get_leaky_alpha_from_max()`` finds an alpha node."""
        alpha_node = self._get_leaky_alpha_from_max()
        return alpha_node is not None

    def is_leaky_mul(self):
        """Return True when this node has a ``Const`` input named ``alpha`` and feeds a ``Maximum``."""
        alpha_chain = [BwdChainNode(op="Const", name="alpha")]
        maximum_chain = [FwdChainNode(op="Maximum")]
        if look_for_node(self._graph, self, alpha_chain) is None:
            return False
        return look_for_node(self._graph, self, maximum_chain) is not None

    def is_keras_prelu(self):
        """Return True when this node is part of a Relu / Neg-Relu-Mul / Add structure.

        The pattern is probed from a ``Relu``, ``Neg`` or add node; any other op
        yields False.
        """
        if self.op == "Relu":
            prelu_chain = [
                FwdChainNode(op="Add"),
                BwdChainNode(op="Mul"),
                BwdChainNode(op="Relu"),
                BwdChainNode(op="Neg"),
            ]
        elif self.op == "Neg":
            prelu_chain = [
                FwdChainNode(op="Relu"),
                FwdChainNode(op="Mul"),
                FwdChainNode(op="Add"),
                BwdChainNode(op="Relu"),
            ]
        elif self.op in ADD_OPS:
            prelu_chain = [BwdChainNode(op="Mul"), BwdChainNode(op="Relu"), BwdChainNode(op="Neg")]
        else:
            return False

        matched = look_for_node(self._graph, self, prelu_chain)
        return matched is not None

    def is_hardsigmoid(self):
        """Return True when this ``Mul`` is followed by ``Add`` -> ``Minimum`` -> ``Maximum``."""
        if self.op != "Mul":
            return False

        clip_chain = [FwdChainNode(op="Add"), FwdChainNode(op="Minimum"), FwdChainNode(op="Maximum")]
        tail_node = look_for_node(self._graph, self, clip_chain)
        return tail_node is not None

    def get_leaky_alpha(self, is_leaky_max=False):
        """Return the leaky-relu alpha for this node.

        Args:
            is_leaky_max: when truthy, alpha is read from the constant found by
                ``_get_leaky_alpha_from_max()``; otherwise it is read from this
                node's own ``alpha`` attribute.

        Raises:
            CantFindLeakyAlphaError: if ``is_leaky_max`` is set and no alpha
                constant is found.
        """
        if not is_leaky_max:
            return self._info.attr["alpha"].f

        # TF1.13 Upgrade: Backward compatibility fix, TF1.7 LeakyRelu was built from mul+max
        alpha_node = self._get_leaky_alpha_from_max()
        if alpha_node is None:
            raise CantFindLeakyAlphaError(self.name)
        return alpha_node._info.attr["value"].tensor.float_val[0]

    def get_prelu_slope(self):
        """Return the PReLU slope values as a flattened float array.

        The slope variable is located through the chain
        ``Add`` (forward) -> ``Mul`` -> ``Neg`` -> ``ReadVariableOp`` ->
        ``VarHandleOp`` named ``alpha`` (all backward), and its value is read
        from the graph's ``values`` mapping.

        Raises:
            CantFindPReluSlopeError: if the chain is not found or the variable
                has no entry in the graph values.
        """
        chain = [
            FwdChainNode(op="Add"),
            BwdChainNode(op="Mul"),
            BwdChainNode(op="Neg"),
            BwdChainNode(op="ReadVariableOp"),
            BwdChainNode(op="VarHandleOp", name="alpha"),
        ]

        alpha_node = look_for_node(self._graph, self, chain)
        # A failed lookup must surface as the dedicated error rather than an
        # AttributeError raised while dereferencing None.
        if alpha_node is None or alpha_node.name not in self.graph.values:
            raise CantFindPReluSlopeError(self.name)

        return np.array(self._graph.values[alpha_node.name], dtype=float).flatten()

    def get_hardsigmoid_info(self):
        """Return ``(alpha, beta)`` of a hard-sigmoid that starts at this node.

        Alpha is the first float of the ``Const`` feeding this node; beta is the
        first float of the ``Const`` feeding the following ``Add``.

        Raises:
            CantFindSigmoidParametersError: if either constant is not found.
        """
        alpha_const = look_for_node(self._graph, self, [BwdChainNode(op="Const")])
        if not alpha_const:
            raise CantFindSigmoidParametersError(self.name)
        alpha = alpha_const._info.attr["value"].tensor.float_val[0]

        beta_chain = [FwdChainNode(op="Add"), BwdChainNode(op="Const")]
        beta_const = look_for_node(self._graph, self, beta_chain)
        if not beta_const:
            raise CantFindSigmoidParametersError(self.name)
        beta = beta_const._info.attr["value"].tensor.float_val[0]

        return alpha, beta

    def get_min_max_info(self):
        """Return the ``(min_value, max_value)`` clipping bounds rooted at this node.

        For a ``Minimum`` node the upper bound comes from its constant input
        (optionally behind up to three ``ExpandDims``); the lower bound comes
        from a following ``Maximum`` node, or is ``-inf`` when there is none.
        A ``Maximum`` node is handled symmetrically with ``+inf`` as the default
        upper bound.

        Returns:
            ``(None, None)`` when the op is neither ``Minimum`` nor ``Maximum``,
            when this node has no constant bound, or when the paired node exists
            but has no constant bound.
        """
        const_chains = [
            [
                BwdChainNode(op="ExpandDims"),
                BwdChainNode(op="ExpandDims"),
                BwdChainNode(op="ExpandDims"),
                BwdChainNode(op="Const"),
            ],
            [BwdChainNode(op="ExpandDims"), BwdChainNode(op="ExpandDims"), BwdChainNode(op="Const")],
            [BwdChainNode(op="ExpandDims"), BwdChainNode(op="Const")],
            [BwdChainNode(op="Const")],
        ]

        if self.op == "Minimum":
            upper_node = get_node_from_possible_chains(self._graph, self, const_chains)
            if not upper_node:
                return None, None
            upper = upper_node.get_const_val()

            following_max = look_for_node(self._graph, self, [FwdChainNode(op="Maximum")])
            if not following_max:
                # a lone (min) clips to [-np.inf, max_value]
                return -np.inf, upper

            # (min) -> (max) clips to [min_value, max_value]
            lower_node = get_node_from_possible_chains(self._graph, following_max, const_chains)
            if not lower_node:
                return None, None
            return lower_node.get_const_val(), upper

        if self.op == "Maximum":
            lower_node = get_node_from_possible_chains(self._graph, self, const_chains)
            if not lower_node:
                return None, None
            lower = lower_node.get_const_val()

            following_min = look_for_node(self._graph, self, [FwdChainNode(op="Minimum")])
            if not following_min:
                # a lone (max) clips to [min_value, np.inf]
                return lower, np.inf

            # (max) -> (min) clips to [min_value, max_value]
            upper_node = get_node_from_possible_chains(self._graph, following_min, const_chains)
            if not upper_node:
                return None, None
            return lower, upper_node.get_const_val()

        return None, None

    def is_resize_image_with_crop_or_pad(self):
        """Return True if an ``Identity`` named ``control_dependency`` fed by a ``Slice`` named ``Slice`` precedes this node."""
        crop_chain = [
            BwdChainNode(op="Identity", name="control_dependency"),
            BwdChainNode(op="Slice", name="Slice"),
        ]
        slice_node = look_for_node(self._graph, self, crop_chain)
        return slice_node is not None

    def is_1d(self):
        """Return True when this ``ExpandDims`` opens an ExpandDims -> op -> Squeeze pattern.

        Two variants are accepted:
        * ``Conv2D`` -> ``Squeeze`` with both the expanded and the squeezed dim in ``[1, -3]``;
        * ``MaxPool`` -> ``Squeeze`` with both the expanded and the squeezed dim equal to 2.
        The ``Conv2D`` variant takes precedence when both chains are found.
        """
        if self.op != "ExpandDims":
            return False

        dim_const = look_for_node(self._graph, self, [BwdChainNode(op="Const", name="dim")])
        if not dim_const:
            return False

        squeeze_after_conv = look_for_node(
            self._graph,
            self,
            [FwdChainNode(op="Conv2D"), FwdChainNode(op="Squeeze")],
        )
        squeeze_after_pool = look_for_node(
            self._graph,
            self,
            [FwdChainNode(op="MaxPool"), FwdChainNode(op="Squeeze")],
        )
        expanded_dim = tf.make_ndarray(dim_const._info.attr["value"].tensor)

        if squeeze_after_conv:
            squeezed_dim = squeeze_after_conv._info.attr["squeeze_dims"].list.i[0]
            return expanded_dim in [1, -3] and squeezed_dim in [1, -3]

        if squeeze_after_pool:
            squeezed_dim = squeeze_after_pool._info.attr["squeeze_dims"].list.i[0]
            return not (expanded_dim != 2 or squeezed_dim != 2)

        return False

    def is_1d_maxpool(self):
        """Return True for a ``MaxPool`` wrapped by ``ExpandDims`` (dim 2) and ``Squeeze`` (dim 2)."""
        if self.op != "MaxPool":
            return False

        expand_dim_chain = [BwdChainNode(op="ExpandDims"), BwdChainNode(op="Const", name="dim")]
        dim_const = look_for_node(self._graph, self, expand_dim_chain)
        squeeze = look_for_node(self._graph, self, [FwdChainNode(op="Squeeze")])

        if not dim_const:
            return False
        if tf.make_ndarray(dim_const._info.attr["value"].tensor) != 2:
            return False
        if not squeeze:
            return False
        if squeeze._info.attr["squeeze_dims"].list.i[0] != 2:
            return False
        return True

    def is_threshold_activation(self):
        """Return True when this node is part of a ``Greater`` -> ``Cast`` -> ``Mul`` structure.

        A ``Mul`` node looks backwards for ``Cast`` <- ``Greater``; any other op
        looks forwards for ``Cast`` -> ``Mul``. Both lookups use ``exact_match=True``.
        """
        if self.op == "Mul":
            threshold_chain = [BwdChainNode(op="Cast"), BwdChainNode(op="Greater")]
        else:
            threshold_chain = [FwdChainNode(op="Cast"), FwdChainNode(op="Mul")]

        matched = look_for_node(self._graph, self, threshold_chain, exact_match=True)
        return matched is not None

    def get_threshold(self):
        """Return the first float of the ``Const`` behind ``Cast`` <- ``Greater``.

        Raises:
            CantFindThresholdError: if that constant is not found.
        """
        threshold_chain = [BwdChainNode(op="Cast"), BwdChainNode(op="Greater"), BwdChainNode(op="Const")]
        threshold_const = look_for_node(self._graph, self, threshold_chain)
        if threshold_const:
            return threshold_const._info.attr["value"].tensor.float_val[0]
        raise CantFindThresholdError(self.name)

    def is_biased_delta_activation(self):
        """Return True when this node belongs to a Sign/Abs pair (either order) ending in ``Mul``.

        Supported starting points are ``Sign``, ``Abs`` and the final ``Mul``;
        every other op yields False.
        """
        if self.op == "Mul":
            backward_options = [
                [BwdChainNode(op="Abs"), BwdChainNode(op="Sign")],
                [BwdChainNode(op="Sign"), BwdChainNode(op="Abs")],
            ]
            return get_node_from_possible_chains(self._graph, self, backward_options) is not None

        if self.op == "Sign":
            partner_op = "Abs"
        elif self.op == "Abs":
            partner_op = "Sign"
        else:
            return False

        forward_chain = [FwdChainNode(op=partner_op), FwdChainNode(op="Mul")]
        return look_for_node(self._graph, self, forward_chain) is not None

    def is_gelu_activation(self):
        """Return True when this node is part of an erf-based GELU structure.

        The structure searched for is ``RealDiv`` -> ``Erf`` -> add -> ``Mul``,
        where the last ``Mul`` is also fed by another ``Mul`` with a ``Const``
        input. The check can start from the ``RealDiv``, either ``Mul``, the add
        or the ``Erf`` node; any other op yields False.
        """
        graph = self._graph

        if self.op == "RealDiv":
            scale_const = look_for_node(
                graph,
                self,
                [
                    FwdChainNode(op="Erf"),
                    FwdChainNode(op="Add"),
                    FwdChainNode(op="Mul"),
                    BwdChainNode(op="Mul"),
                    BwdChainNode(op="Const"),
                ],
            )
            return scale_const is not None

        if self.op == "Mul":
            # first mul of the structure: the erf branch hangs off the following mul
            from_first_mul = look_for_node(
                graph,
                self,
                [
                    FwdChainNode(op="Mul"),
                    BwdChainNode(op="Add"),
                    BwdChainNode(op="Erf"),
                    BwdChainNode(op="RealDiv"),
                    BwdChainNode(op="Const"),
                ],
            )
            if from_first_mul is not None:
                return True

            # otherwise check whether this is the second mul node in the gelu chain
            scale_const = look_for_node(graph, self, [BwdChainNode(op="Mul"), BwdChainNode(op="Const")])
            div_const = look_for_node(
                graph,
                self,
                [
                    BwdChainNode(op="Add"),
                    BwdChainNode(op="Erf"),
                    BwdChainNode(op="RealDiv"),
                    BwdChainNode(op="Const"),
                ],
            )
            return scale_const is not None and div_const is not None

        if self.op in ADD_OPS:
            scale_const = look_for_node(
                graph,
                self,
                [FwdChainNode(op="Mul"), BwdChainNode(op="Mul"), BwdChainNode(op="Const")],
            )
            div_const = look_for_node(
                graph,
                self,
                [BwdChainNode(op="Erf"), BwdChainNode(op="RealDiv"), BwdChainNode(op="Const")],
            )
            return scale_const is not None and div_const is not None

        if self.op == "Erf":
            scale_const = look_for_node(
                graph,
                self,
                [FwdChainNode(op="Add"), FwdChainNode(op="Mul"), BwdChainNode(op="Mul"), BwdChainNode(op="Const")],
            )
            div_node = look_for_node(graph, self, [BwdChainNode(op="RealDiv")])
            return scale_const is not None and div_node is not None

        return False

    def should_skip_gelu(self):
        """Tell whether this ``Mul`` is the first mul of a GELU structure.

        Returns:
            False for non-``Mul`` ops; otherwise the raw result of the chain
            lookup (the matched node, or whatever the lookup yields on a miss),
            so callers should rely on truthiness only.
        """
        if self.op != "Mul":
            return False

        gelu_chain = [
            FwdChainNode(op="Mul"),
            BwdChainNode(op="Add"),
            BwdChainNode(op="Erf"),
            BwdChainNode(op="RealDiv"),
            BwdChainNode(op="Const"),
        ]
        return look_for_node(self._graph, self, gelu_chain)

    def is_l2_normalization_square(self):
        """Return True when this ``Square`` starts an L2-normalization structure.

        The structure searched for is ``Square`` -> ``Sum`` -> ``Maximum`` ->
        ``Rsqrt`` -> ``Mul``. Any op other than ``Square`` yields False.
        """
        if self.op != "Square":
            # Explicit False (rather than an implicit None) keeps this predicate
            # consistent with the other is_* checks.
            return False

        l2_chain = [
            FwdChainNode(op="Sum"),
            FwdChainNode(op="Maximum"),
            FwdChainNode(op="Rsqrt"),
            FwdChainNode(op="Mul"),
        ]
        return look_for_node(self._graph, self, l2_chain) is not None

    def is_silu_activation(self):
        """Return True when this node is part of ``Sigmoid`` -> ``Mul`` -> ``IdentityN``.

        The check can start from any of the three nodes; other ops yield False.
        """
        if self.op == "Sigmoid":
            tail = look_for_node(self._graph, self, [FwdChainNode(op="Mul"), FwdChainNode(op="IdentityN")])
            return tail is not None

        if self.op == "Mul":
            tail = look_for_node(self._graph, self, [FwdChainNode(op="IdentityN")])
            head = look_for_node(self._graph, self, [BwdChainNode(op="Sigmoid")])
            return tail is not None and head is not None

        if self.op == "IdentityN":
            head = look_for_node(self._graph, self, [BwdChainNode(op="Mul"), BwdChainNode(op="Sigmoid")])
            return head is not None

        return False

    def is_swish_activation_first_mul(self):
        """Return True when this ``Mul`` is followed by ``Sigmoid`` -> ``Mul``."""
        if self.op != "Mul":
            return False

        swish_tail = look_for_node(self._graph, self, [FwdChainNode(op="Sigmoid"), FwdChainNode(op="Mul")])
        return swish_tail is not None

    def is_swish_activation_second_mul(self):
        """Return True when this ``Mul`` is preceded by ``Sigmoid`` <- ``Mul`` <- ``Const``."""
        if self.op != "Mul":
            return False

        beta_chain = [BwdChainNode(op="Sigmoid"), BwdChainNode(op="Mul"), BwdChainNode(op="Const")]
        beta_const = look_for_node(self._graph, self, beta_chain)
        return beta_const is not None

    def get_swish_beta(self):
        """Return the first float of the ``Const`` behind ``Sigmoid`` <- ``Mul``.

        Raises:
            CantFindSwishBetaError: if that constant is not found.
        """
        beta_chain = [BwdChainNode(op="Sigmoid"), BwdChainNode(op="Mul"), BwdChainNode(op="Const")]
        beta_const = look_for_node(self._graph, self, beta_chain)
        if beta_const:
            return beta_const._info.attr["value"].tensor.float_val[0]
        raise CantFindSwishBetaError(self.name)

    def is_mish_activation(self):
        """Return True when this node is part of ``Softplus`` -> ``Tanh`` -> ``Mul``.

        Starting from ``Softplus`` only the forward chain is required. Starting
        from ``Mul`` the backward chain must exist, the ``Mul`` and the
        ``Softplus`` must share at least one predecessor, and ``is_ew_mult()``
        must hold. Other ops yield False.
        """
        if self.op == "Softplus":
            mish_vertices = get_all_nodes_in_chain(self._graph, self, [FwdChainNode(op="Tanh"), FwdChainNode(op="Mul")])
            if mish_vertices is not None:
                return True

        elif self.op == "Mul":
            softplus = look_for_node(self._graph, self, [BwdChainNode(op="Tanh"), BwdChainNode(op="Softplus")])
            if softplus:
                # Materialise the predecessors: a one-shot iterator would be
                # consumed by the first membership test below, making the result
                # depend on the order of this node's own predecessors.
                softplus_preds = list(self._graph.predecessors(softplus))
                common_stem = any(x in softplus_preds for x in self._graph.predecessors(self))
                return common_stem and self.is_ew_mult()

        return False

    def is_hardswish_activation(self):
        """Return True when this node is part of add(3) -> ``Relu6`` -> ``RealDiv``(6) -> ``Mul``.

        The check can start from the add node or from the final ``Mul``. Besides
        the chain itself it requires that the add and the ``Mul`` share a
        predecessor, that the ``Mul`` passes ``is_ew_mult()``, and that the
        normalization info of the add / div nodes reports mean -3.0 with std 1.0
        and mean 0.0 with std 6.0 respectively.
        """
        if self.op in ADD_OPS:
            chain_nodes = get_all_nodes_in_chain(
                self._graph,
                self,
                [FwdChainNode(op="Relu6"), FwdChainNode(op="RealDiv"), FwdChainNode(op="Mul")],
            )
            if not chain_nodes:
                return False
            _relu6_node, div_node, mul_node = chain_nodes
            add_node = self
        elif self.op == "Mul":
            chain_nodes = get_all_nodes_in_chain(
                self._graph,
                self,
                [BwdChainNode(op="RealDiv"), BwdChainNode(op="Relu6"), BwdChainNode(op="Add")],
            )
            if not chain_nodes:
                return False
            div_node, _relu6_node, add_node = chain_nodes
            mul_node = self
        else:
            return False

        # the add and the mul must both consume the same stem node
        shares_stem = any(
            mul_pred in self._graph.predecessors(add_node) for mul_pred in self._graph.predecessors(mul_node)
        )

        add_mean, add_std, _ = add_node.get_normalization_info()
        adds_three = all(value == -3.0 for value in add_mean) and all(value == 1.0 for value in add_std)

        div_mean, div_std, _ = div_node.get_normalization_info()
        divides_by_six = all(value == 0.0 for value in div_mean) and all(value == 6.0 for value in div_std)

        return bool(shares_stem and mul_node.is_ew_mult() and adds_three and divides_by_six)

    def is_ew_mult(self):
        """Return True for a ``Mul`` with exactly two predecessors, none of which has an op in ``VAR_OPS``."""
        if self.op != "Mul":
            return False

        operands = list(self._graph.predecessors(self))
        if len(operands) != 2:
            return False
        return all(operand.op not in VAR_OPS for operand in operands)

    def is_ew_div(self):
        """Return True for a ``RealDiv`` with exactly two predecessors, none of which has an op in ``VAR_OPS``."""
        if self.op != "RealDiv":
            return False

        operands = list(self._graph.predecessors(self))
        if len(operands) != 2:
            return False
        return all(operand.op not in VAR_OPS for operand in operands)

    def is_ew_sub(self):
        """Return True for a ``Sub`` with exactly two predecessors, none of which has an op in ``VAR_OPS``."""
        if self.op != "Sub":
            return False

        operands = list(self._graph.predecessors(self))
        if len(operands) != 2:
            return False
        return all(operand.op not in VAR_OPS for operand in operands)

    def is_normalization(self):
        """Return True when this node can be read as (part of) a normalization.

        Nodes whose first output shape has rank 2 are never normalizations.
        ``Neg`` always qualifies. Add/``Sub`` nodes qualify with exactly one
        ``Const`` predecessor. ``Mul``/``RealDiv`` nodes qualify with exactly two
        predecessors, exactly one of them ``Const``, and at most one successor.
        """
        # Cases covered (always starting from the first encountered operator):
        # 1. Add/Sub -> Mul/Div: normalization of the form (x-Mean(x))/(Std(x))
        # 2. Mul/Div -> Add/Sub: equivalent normalization of the form (x-Mean(x)*Std(x))/(Std(x))
        # 3. Add/Sub: normalization with Std(x)=1 (private case Neg has Std(x)=-1)
        # 4. Mul/Div: normalization with Mean(x)=0
        # For this predicate, 3+4 actually cover 1+2 respectively

        if len(self.get_output_shapes()[0]) == 2:
            return False

        if self.op in [*ADD_OPS, "Sub", "Neg"]:
            if self.op == "Neg":
                return True
            const_count = sum(1 for pred in self._graph.predecessors(self) if pred.op == "Const")
            return const_count == 1

        if self.op in ["Mul", "RealDiv"]:
            operands = list(self._graph.predecessors(self))
            consumers = list(self._graph.successors(self))
            const_operands = [operand for operand in operands if operand.op == "Const"]
            return len(operands) == 2 and len(const_operands) == 1 and len(consumers) <= 1

        return False

    def is_multi_vertex_normalization(self):
        """Return True when the normalization starting here spans a second node.

        Requires ``is_normalization()`` and exactly one successor. An add/``Sub``
        node must be followed by a ``Mul`` (preferred) or ``RealDiv`` that has a
        ``Const`` input; a ``Mul``/``RealDiv`` node must be followed by a ``Sub``
        (preferred) or add node that has a ``Const`` input. ``Neg`` never
        qualifies.
        """
        if not self.is_normalization():
            return False

        if len(list(self._graph.successors(self))) != 1:
            return False

        if self.op == "Neg":
            return False

        if self.op in [*ADD_OPS, "Sub"]:
            mul_succ = look_for_node(self._graph, self, [FwdChainNode(op="Mul")])
            div_succ = look_for_node(self._graph, self, [FwdChainNode(op="RealDiv")])
            # prefer the Mul successor, fall back to the RealDiv one
            scale_node = mul_succ or div_succ
            if not scale_node:
                return False
            scale_const = look_for_node(self._graph, scale_node, [BwdChainNode(op="Const")])
            return scale_const is not None

        if self.op in ["Mul", "RealDiv"]:
            sub_succ = look_for_node(self._graph, self, [FwdChainNode(op="Sub")])
            add_succ = get_node_from_possible_chains(
                self._graph,
                self,
                [[FwdChainNode(op="Add")], [FwdChainNode(op="AddV2")]],
            )
            # prefer the Sub successor, fall back to the Add/AddV2 one
            shift_node = sub_succ or add_succ
            if not shift_node:
                return False
            shift_const = look_for_node(self._graph, shift_node, [BwdChainNode(op="Const")])
            return shift_const is not None

        return False

    def get_normalization_info(self):
        # Possible cases (always assume start from the first encountered operator):
        # 1. Add/Sub -> Mul/Div: normalization of the form (x-Mean(x))/(Std(x))
        # 2. Mul/Div -> Add/Sub: equivalent normalization of the form (x-Mean(x)*Std(x))/(Std(x))
        # 3. Add/Sub: normalization with Std(x)=1 (private case Neg has Std(x)=-1)
        # 4. Mul/Div: normalization with Mean(x)=0
        const_preds = [pred for pred in self._graph.predecessors(self) if pred.op == "Const"]
        if self.op in [*ADD_OPS, "Sub"]:
            # covering case #3
            raw_mean = list(const_preds[0]._info.attr["value"].tensor.float_val)
            if not raw_mean:
                raw_mean = list(np.fromstring(const_preds[0]._info.attr["value"].tensor.tensor_content, np.float32))
            # either add -mean or subtract mean
            mean = [x * -1.0 for x in raw_mean] if self.op in ADD_OPS else raw_mean

            # covering case #1
            std_node = None
            std_const_node = None
            mul_node = look_for_node(self._graph, self, [FwdChainNode(op="Mul")])
            div_node = look_for_node(self._graph, self, [FwdChainNode(op="RealDiv")])
            if mul_node:
                std_node = mul_node
                std_const_node = look_for_node(self._graph, mul_node, [BwdChainNode(op="Const")])
            elif div_node:
                std_node = div_node
                std_const_node = look_for_node(self._graph, div_node, [BwdChainNode(op="Const")])

            # either multiply by 1/std or divide by std
            raw_std = list(std_const_node._info.attr["value"].tensor.float_val) if std_const_node else [1.0]
            if not raw_std:
                raw_std = list(np.fromstring(std_const_node._info.attr["value"].tensor.tensor_content, np.float32))
            std = [1 / x for x in raw_std] if std_node is mul_node else raw_std

            return mean, std, std_node.get_output_shapes() if std_node else self.get_output_shapes()

        elif self.op in ["Neg"]:
            # can also be seen as case #3
            return [0.0], [-1.0], self.get_output_shapes()

        elif self.op in ["Mul", "RealDiv"]:
            # covering case #4
            std_node = const_preds[0] if const_preds else None
            if not std_node:
                raise UnsupportedNormalizationLayerError(
                    f"Could not find std values in normalization starting from node {self.name}.",
                )

            # covering case #2: multiply mean by std
            mean_node = None
            mean_const_node = None
            sub_node = look_for_node(self._graph, self, [FwdChainNode(op="Sub")])
            add_node = get_node_from_possible_chains(
                self._graph,
                self,
                [[FwdChainNode(op="Add")], [FwdChainNode(op="AddV2")]],
            )
            if sub_node:
                mean_node = sub_node
                mean_const_node = look_for_node(self._graph, sub_node, [BwdChainNode(op="Const")])
            elif add_node:
                mean_node = add_node
                mean_const_node = look_for_node(self._graph, add_node, [BwdChainNode(op="Const")])

            # either multiply by 1/std or divide by std (except for a mul-by-const scenario)
            raw_std = list(std_node._info.attr["value"].tensor.float_val)
            if not raw_std:
                raw_std = list(np.fromstring(std_node._info.attr["value"].tensor.tensor_content, np.float32))
            if self.op == "Mul":
                raw_std = np.array(raw_std)
                raw_std = np.reciprocal(raw_std)
                raw_std[raw_std == np.inf] = 0.0
                raw_std.tolist()
            std = raw_std

            # either add -mean or subtract mean
            raw_mean = list(mean_const_node._info.attr["value"].tensor.float_val) if mean_const_node else [0.0]
            if not raw_mean:
                raw_mean = list(np.fromstring(mean_const_node._info.attr["value"].tensor.tensor_content, np.float32))
            raw_mean = [x * -1.0 for x in raw_mean] if mean_node is add_node else raw_mean
            mean = [x * y for x, y in zip(raw_mean, std)]

            return mean, std, mean_node.get_output_shapes() if mean_node else self.get_output_shapes()

        raise UnsupportedNormalizationLayerError(
            f"Could not find std/mean values in normalization starting from node {self.name}.",
        )

    def get_delta_bias_value(self):
        """Return the constant multiplier of a Sign/Abs "biased delta" pattern.

        Starting from a ``Sign`` node the chain ``Abs -> Mul <- Const`` is searched;
        starting from an ``Abs`` node the chain ``Sign -> Mul <- Const`` is searched.

        Returns:
            The first ``float_val`` entry of the located Const, or ``None`` when this
            node is neither ``Sign`` nor ``Abs``.

        Raises:
            UnsupportedModelError: the node is Sign/Abs but the chain was not found.

        """
        # Each of the two ops expects the other one as its direct successor.
        partner_op_by_op = {"Sign": "Abs", "Abs": "Sign"}
        partner_op = partner_op_by_op.get(self.op)
        if partner_op is None:
            return None

        chain = [FwdChainNode(op=partner_op), FwdChainNode(op="Mul"), BwdChainNode(op="Const")]
        delta_const = look_for_node(self._graph, self, chain)
        if not delta_const:
            raise UnsupportedModelError(
                f"Could not find delta bias value in biased_delta activation starting from node {self.name}.",
            )
        return delta_const._info.attr["value"].tensor.float_val[0]

    def is_valid_reduce_max_min(self):
        """Return whether this reduce max/min node reduces a single, last (index 3 / -1) axis with keep_dims.

        The axis is read from a Const found backwards from this node. The values are
        decoded through ``get_axes_value`` so that a Const storing its data in
        ``tensor_content`` (empty ``int_val``) yields ``False`` for an unsupported
        axis set instead of raising ``IndexError``.

        Returns:
            A falsy value when no axis Const is found, when the number of axes is
            not exactly one, when ``keep_dims`` is unset, or when the axis is not
            the last one; a truthy value otherwise.

        """
        axis_node = look_for_node(self._graph, self, [BwdChainNode(op="Const")])
        if not axis_node:
            return False

        # get_axes_value maps negative axes to their rank-4 equivalent (-1 -> 3),
        # which is still covered by the membership test below.
        axes = self.get_axes_value(axis_node)
        if len(axes) != 1:
            return False
        return self._info.attr["keep_dims"].b and (axes[0] in (-1, 3))

    def is_valid_reduce_sum(self):
        """Check whether this reduce-sum node is supported and report its axes.

        Returns:
            A ``(is_valid, axes)`` pair. ``(False, [])`` when no Const is found
            backwards from this node; otherwise ``is_valid`` is truthy only when
            ``keep_dims`` is set and axis 0 is not among the reduced axes.

        """
        axes_const = look_for_node(self._graph, self, [BwdChainNode(op="Const")])
        if not axes_const:
            return False, []

        reduced_axes = self.get_axes_value(axes_const)
        is_valid = self._info.attr["keep_dims"].b and 0 not in reduced_axes
        return is_valid, reduced_axes

    def is_valid_reduce_l2(self):
        """Validate the ``Sum`` node that directly follows this node.

        Returns:
            ``(False, [])`` when no ``Sum`` successor is found, otherwise the
            ``(is_valid, axes)`` pair produced by that node's ``is_valid_reduce_sum``.

        """
        following_sum = look_for_node(self._graph, self, [FwdChainNode(op="Sum")])
        if following_sum:
            return following_sum.is_valid_reduce_sum()
        return False, []

    def get_sum_axes(self):
        """Return the axes of the ``Sum`` reached backwards through ``Rsqrt`` and ``Maximum``.

        Returns:
            The axes list decoded by ``get_axes_value``, or ``[]`` when the Sum
            node has no Const found backwards from it.

        """
        backward_chain = [BwdChainNode(op="Rsqrt"), BwdChainNode(op="Maximum"), BwdChainNode(op="Sum")]
        reduce_sum = look_for_node(self._graph, self, backward_chain)

        axes_const = look_for_node(self._graph, reduce_sum, [BwdChainNode(op="Const")])
        return self.get_axes_value(axes_const) if axes_const else []

    def get_axes_value(self, axis_node):
        """Decode the axes stored in a Const node and make them non-negative.

        Args:
            axis_node: node whose ``_info.attr["value"].tensor`` holds the axes,
                either in ``int_val`` or packed in ``tensor_content``. May be falsy.

        Returns:
            A list of axes where every negative axis ``a`` is replaced by ``4 + a``
            (i.e. a rank of 4 is assumed). ``[]`` when ``axis_node`` is falsy.

        """
        if not axis_node:
            return []

        tensor = axis_node._info.attr["value"].tensor
        axes = list(tensor.int_val)
        if not axes:
            # Packed values are read as little-endian 32-bit signed integers. Decoding the
            # full word (rather than only its low byte) keeps negative axes negative, so
            # the normalisation below can apply to them.
            # NOTE(review): assumes the Const dtype is int32 - confirm for int64 axes.
            content = bytes(tensor.tensor_content)
            content = content[: len(content) - len(content) % 4]
            axes = np.frombuffer(content, dtype="<i4").tolist() if content else []
        return [4 + axis if axis < 0 else axis for axis in axes]

    def is_reduce_l2(self):
        """Return whether this square node is followed by ``Sum`` and then ``Sqrt``."""
        if self.is_square():
            forward_chain = [FwdChainNode(op="Sum"), FwdChainNode(op="Sqrt")]
            sqrt_node = look_for_node(self.graph, self, forward_chain)
            return sqrt_node is not None
        return False

    def is_l2_normalization(self):
        """Return whether this ``Mul`` node is the final multiply of an L2-normalization pattern.

        The pattern is recognised by walking backwards from this node through
        ``Rsqrt``, ``Maximum``, ``Sum`` and ``Square``.

        Returns:
            ``True`` when this node is a ``Mul`` and the backward chain exists,
            ``False`` otherwise.

        """
        # The guard must test this node's op; the previous `not "Mul"` tested a
        # non-empty string literal and therefore never rejected anything.
        if self.op != "Mul":
            return False

        node = look_for_node(
            self.graph,
            self,
            [BwdChainNode(op="Rsqrt"), BwdChainNode(op="Maximum"), BwdChainNode(op="Sum"), BwdChainNode(op="Square")],
        )
        return node is not None

    def is_square(self):
        """Return whether this node acts as a squaring operation.

        The op must belong to ``SQUARE_OPS``. A ``Pow`` node only counts when its
        power is not 1.0 (other powers are validated when the layer is created),
        and a ``Mul`` node only counts when it has exactly two inputs that are the
        same tensor.
        """
        op_type = self.op
        if op_type not in SQUARE_OPS:
            return False

        if op_type == "Pow":
            # pow(1.0) is skipped
            return self.get_power() != 1.0

        if op_type == "Mul":
            operands = self._info.input
            return len(operands) == 2 and operands[0] == operands[1]

        return True

    def is_mul_by_2_ew_add(self):
        """Return whether this add node adds a single predecessor to itself (x + x)."""
        if self.op not in ADD_OPS:
            return False

        producers = list(self._graph.predecessors(self))
        if len(producers) != 1:
            return False

        operands = self.input
        if len(operands) != 2:
            return False

        # Both operands must be the very same input, and it must be the only predecessor.
        return operands[0] == operands[1] and producers[0].name == operands[0]

    def is_space_to_depth(self):
        """Return whether this slice/concat node belongs to a block-size-2 space-to-depth pattern.

        The pattern is one producer feeding exactly four strided slices that all
        end in the same concat node. Every slice must use stride 2 on height and
        width and take the whole feature axis, and the four
        ``(height_slice_start, width_slice_start)`` pairs must be exactly
        ``(0, 0), (0, 1), (1, 0), (1, 1)`` - each appearing once.

        Returns:
            ``True`` when the pattern matches, ``False`` otherwise (including when
            the expected predecessors are missing).

        """
        slice_op = self
        if self.op in CONCAT_OPS:
            slice_op = next(iter(self._graph.predecessors(self)), None)
        if slice_op is None or slice_op.op not in SLICE_OPS:
            return False

        pred = next((x for x in self._graph.predecessors(slice_op) if x.op != "Const"), None)
        if pred is None:
            return False
        input_shape = pred.get_output_shapes()[0]
        if len(input_shape) != 4:
            return False
        input_height, input_width, input_features = input_shape[1:]

        slice_start_pairs = []
        concat_node = None
        for succ in self._graph.successors(pred):
            if succ.op not in SLICE_OPS:
                return False
            height_slice, width_slice, feature_slice = succ.get_slices_values(allow_stride=True)
            next_successors = list(self._graph.successors(succ))
            if len(next_successors) != 1 or next_successors[0].op not in CONCAT_OPS:
                return False
            # Remember the first concat seen so that every slice is required to feed the same one.
            if concat_node is None:
                concat_node = next_successors[0]
            elif concat_node != next_successors[0]:
                return False
            # Each entry is [start, stop, step]; a stop of 0 is accepted alongside the full size.
            if (
                height_slice[1] not in [0, input_height]
                or height_slice[2] != 2
                or width_slice[1] not in [0, input_width]
                or width_slice[2] != 2
                or feature_slice[0] != 0
                or feature_slice[1] not in [0, input_features]
                or feature_slice[2] != 1
            ):
                return False
            slice_start_pairs.append((int(height_slice[0]), int(width_slice[0])))

        expected_pairs = {(0, 0), (0, 1), (1, 0), (1, 1)}
        return len(slice_start_pairs) == 4 and set(slice_start_pairs) == expected_pairs

    def is_1x1_einsum(self):
        """Return whether this einsum has one variable input and a supported 1x1-style equation.

        Exactly two predecessors are required, exactly one of which must be a
        variable layer, and the equation must be ``abc,cde->abde`` or
        ``abcd,cde->abe``.
        """
        if self.op not in EINSUM_OPS:
            return False

        operands = list(self._graph.predecessors(self))
        if len(operands) != 2:
            return False

        variable_operands = [operand for operand in operands if operand._is_var_layer()]
        equation = str(self._info.attr["equation"].s, "utf-8")
        supported_equations = ["abc,cde->abde", "abcd,cde->abe"]
        return len(variable_operands) == 1 and equation in supported_equations

    def is_matmul_einsum(self):
        """Return whether this einsum multiplies two data (non-variable, non-Const) inputs.

        Exactly two predecessors are required, neither of which may be a variable
        layer or a ``Const``, and the equation must be ``aecd,abcd->acbe`` or
        ``acbe,aecd->abcd``.
        """
        if self.op not in EINSUM_OPS:
            return False

        operands = list(self._graph.predecessors(self))
        if len(operands) != 2:
            return False

        data_operands = [
            operand for operand in operands if not operand._is_var_layer() and operand.op != "Const"
        ]
        equation = str(self._info.attr["equation"].s, "utf-8")
        supported_equations = ["aecd,abcd->acbe", "acbe,aecd->abcd"]
        return len(data_operands) == 2 and equation in supported_equations

    def get_einsum_1x1_info(self):
        """Reshape the einsum weights into a ``[1, 1, in, out]`` kernel.

        Returns:
            ``(kernel, kernel.shape)`` where the kernel is the layer's variable
            data reshaped to ``[1, 1, shape[k], shape[2]]``; ``k`` is 0 for
            ``abc,cde->abde`` and 1 for ``abcd,cde->abe``.

        Raises:
            UnsupportedConvLayerError: the equation is neither of the two above.

        """
        weights, _ = self.get_layer_var_data()
        equation = str(self._info.attr["equation"].s, "utf-8")

        # Index of the weights axis that becomes the third kernel dimension.
        third_dim_axis_by_equation = {"abc,cde->abde": 0, "abcd,cde->abe": 1}
        if equation not in third_dim_axis_by_equation:
            raise UnsupportedConvLayerError(f"Node {self.name} could not be converted to a convolution layer")

        third_dim_axis = third_dim_axis_by_equation[equation]
        conv_kernel = weights.reshape([1, 1, weights.shape[third_dim_axis], weights.shape[2]])
        return conv_kernel, conv_kernel.shape

    def should_transpose_einsum_matmul_input(self):
        """Return whether the einsum equation is ``aecd,abcd->acbe``."""
        transposing_equation = "aecd,abcd->acbe"
        raw_equation = self._info.attr["equation"].s
        return str(raw_equation, "utf-8") == transposing_equation

    def is_spatial_flatten_reshape(self):
        """Return whether this reshape merges dims 1 and 2 of a rank-4 input into one dim.

        The first input shape must have rank 4 and the first output shape rank 3,
        equal to ``[d0, d1 * d2, d3]`` of the input.
        """
        source_shape = self.get_input_shapes()[0]
        target_shape = self.get_output_shapes()[0]
        if len(source_shape) != 4 or len(target_shape) != 3:
            return False

        dim0, dim1, dim2, dim3 = source_shape
        return target_shape == [dim0, dim1 * dim2, dim3]

    def is_spatial_concat(self, rank4_dim, rank3_dim):
        """Return whether this concat node concatenates over the given spatial dimension.

        Args:
            rank4_dim: axis index to match when the first output shape has rank 4.
            rank3_dim: axis index to match when the first output shape has rank 3.

        The axis is taken from a Const found backwards from this node; both the
        given index and its negative counterpart (index minus rank) are accepted.

        Raises:
            UnsupportedConcatLayerError: the concat axis is 0.

        """
        if self.op not in CONCAT_OPS:
            return False

        axis_const = look_for_node(self._graph, self, [BwdChainNode(op="Const")])
        if not axis_const:
            return False

        axis = axis_const._info.attr["value"].tensor.int_val[0]
        if axis == 0:
            raise UnsupportedConcatLayerError(f"Concat over the batch dimension is not supported in {self.name}.")

        rank = len(self.get_output_shapes()[0])
        if rank == 3:
            return axis in (rank3_dim, rank3_dim - 3)
        if rank == 4:
            return axis in (rank4_dim, rank4_dim - 4)
        return False

    def is_spatial_h_concat(self):
        """Return whether this node is a concat over axis 1 of a rank-4 output."""
        # rank 3 can't be spatial h, so an index no axis can equal is passed for it
        return self.is_spatial_concat(rank4_dim=1, rank3_dim=np.inf)

    def is_spatial_w_concat(self):
        """Return whether this node is a concat over axis 2 (rank-4 output) or axis 1 (rank-3 output)."""
        return self.is_spatial_concat(rank4_dim=2, rank3_dim=1)

    def get_concat_axis(self):
        """Return the concat axis of this node.

        The width check is made first, then the height check; when neither
        matches, ``DEFAULT_CONCAT_AXIS`` is returned.
        """
        if self.is_spatial_w_concat():
            return ConcatAxis.spatial_w
        if self.is_spatial_h_concat():
            return ConcatAxis.spatial_h
        return DEFAULT_CONCAT_AXIS

    def get_slices_values(self, allow_stride=False):
        """Return the ``[start, stop, step]`` triplets of a ``StridedSlice`` / ``Slice`` node.

        The slicing parameters are read from the node's Const predecessors, which
        are ordered by node name. For ``StridedSlice`` the begin/end/ellipsis masks
        are applied over a 4-entry layout and the first entry is dropped
        (presumably the batch dimension - TODO confirm). For ``Slice`` the start
        and size Consts are converted to start/stop values.

        Args:
            allow_stride: when False, a ``StridedSlice`` step greater than 1 or
                negative on the last two returned axes raises an error.

        Returns:
            ``(height_slices, width_slices, feature_slices)``, each being
            ``[start, stop, step]``. A stop of 0 stands for "not limited"
            (see the inline comments).

        Raises:
            UnsupportedSliceLayerError: new-axis/shrink-axis masks are used, an
                unsupported stride is found, or the size of a ``Slice`` cannot be
                determined.
            UnsupportedModelError: the slice has more than 4 dimensions, or this
                node is not a slice op.

        """
        if self.op in ["StridedSlice", "Slice"]:
            const_preds = sorted(
                [x for x in self._graph.predecessors(self) if x.op == "Const"],
                key=lambda node: node.name,
            )
            # Number of dimensions in the slicing op
            if self.op == "StridedSlice":
                stack = const_preds[0]
                stack_1 = const_preds[1]

                new_axis_mask = self._info.attr["new_axis_mask"].i
                shrink_axis_mask = self._info.attr["shrink_axis_mask"].i
                if new_axis_mask != 0 or shrink_axis_mask != 0:
                    raise UnsupportedSliceLayerError(
                        f"Found new axis or shrink axis in slice node {self.name}, which is not supported",
                    )

                # Masks are bit fields; expand them to per-dimension flags, least significant bit first.
                begin_mask_val = self._info.attr["begin_mask"].i
                begin_mask = [int(d) for d in str(bin(begin_mask_val)[2:])][::-1]
                begin_mask = np.append(begin_mask, (4 - len(begin_mask)) * [0]).astype(np.int32)

                end_mask_val = self._info.attr["end_mask"].i
                end_mask = [int(d) for d in str(bin(end_mask_val)[2:])][::-1]
                end_mask = np.append(end_mask, (4 - len(end_mask)) * [0]).astype(np.int32)

                # Index of the dimension marked by the ellipsis mask, or -1 when there is none.
                ellipsis_str = str(bin(self._info.attr["ellipsis_mask"].i))[:1:-1]
                ellipsis = -1 if "1" not in ellipsis_str else ellipsis_str.index("1")

                stack_2 = const_preds[2]
                # Fetch slicing values from the const, and extend the start, stop and step values for unsliced dims(pad with 0 to avoid slicing)
                start_val = tf.make_ndarray(stack._info.attr["value"].tensor)
                stop_val = tf.make_ndarray(stack_1._info.attr["value"].tensor)
                step_val = tf.make_ndarray(stack_2._info.attr["value"].tensor)
                if len(start_val) > 4:
                    raise UnsupportedModelError(
                        f"Illegal slice in node {self.name} - rank of slice is larger than 4 dimensions.",
                    )

                start, stop, step = [0] * 4, [0] * 4, [1] * 4
                dim_index = 0
                for i in range(len(start_val)):
                    # If ellipsis is used in this dim, need to find how many more dims are specified to be sliced,
                    # and skip those that ignored due to the ellipsis
                    if i == ellipsis:
                        dims_to_skip = 4 - len(start_val)
                        dim_index += dims_to_skip + 1
                        continue
                    # If begin_mask[i] == 1 then the start value for dim i is 0
                    if begin_mask[i] == 0:
                        start[dim_index] = start_val[i]
                    # If end_mask[i] == 1 then the end value for dim i is 0(later changed to output_shape[i])
                    if end_mask[i] == 0:
                        stop[dim_index] = stop_val[i]
                    step[dim_index] = step_val[i]
                    dim_index += 1

                # Drop the first of the four entries; the remaining three are returned.
                start = start[1:]
                stop = stop[1:]
                step = step[1:]
                if any(x > 1 or x < 0 for x in step[1:]) and not allow_stride:
                    raise UnsupportedSliceLayerError(
                        f"Slices with stride > 1 or stride < 0 in width or features axis "
                        f"in node {self.name} are not supported.",
                    )
                if step[0] < 0:
                    raise UnsupportedSliceLayerError(
                        f"Slices with stride < 0 in height axis in node {self.name} are not supported.",
                    )
            else:
                # Slice op
                stack = const_preds[0]
                crop_size = None
                if len(const_preds) > 1:
                    stack_1 = const_preds[1]
                    crop_size = np.append(tf.make_ndarray(stack_1._info.attr["value"].tensor)[1:3], [0]).astype(
                        np.int32,
                    )
                else:
                    pack_node = look_for_node(self.graph, self, [BwdChainNode(op="Pack")])
                    if pack_node:
                        pack_consts = sorted(
                            [x for x in pack_node._graph.predecessors(pack_node) if x.op == "Const"],
                            key=lambda node: node.name,
                        )
                        consts = [tf.make_ndarray(x._info.attr["value"].tensor).tolist() for x in pack_consts]
                        crop_size = np.append(consts, [0]).astype(np.int32)

                # Neither a second Const nor a Pack predecessor supplied the size: report it as an
                # unsupported slice instead of failing later on an unbound local.
                if crop_size is None:
                    raise UnsupportedSliceLayerError(
                        f"Could not find slice size values in slice node {self.name}.",
                    )

                start = np.append(tf.make_ndarray(stack._info.attr["value"].tensor)[1:3], [0]).astype(np.int32)
                # stop = start + size, except that a size of -1 is encoded as a stop of 0.
                stop = [sum(x) if x[1] != -1 else 0 for x in zip(start, crop_size)]
                step = [1] * 3

            (height_slices, width_slices, feature_slices) = (
                [start[0], stop[0], step[0]],
                [start[1], stop[1], step[1]],
                [start[2], stop[2], step[2]],
            )

            return height_slices, width_slices, feature_slices
        else:
            raise UnsupportedModelError(f"Slices cannot be found in node {self.name} of type {self.op}")

    def get_shuffle_last_reshape(self):
        """Return the ``Reshape`` reached forwards through a ``Transpose``, as found by ``look_for_node``."""
        forward_chain = [FwdChainNode(op="Transpose"), FwdChainNode(op="Reshape")]
        return look_for_node(self._graph, self, forward_chain)

    def get_mult_scalar(self):
        """Return the constant factor of this ``Mul`` node.

        Returns:
            The first ``float_val`` entry of a Const found backwards from this node.

        Raises:
            UnexpectedNodeError: this node is not a ``Mul``, or no Const was found.

        """
        if self.op != "Mul":
            raise UnexpectedNodeError(f"Multiply by constant vertex {self.name} is not of type Mul.")

        factor_const = look_for_node(self._graph, self, [BwdChainNode(op="Const")])
        if factor_const:
            return factor_const._info.attr["value"].tensor.float_val[0]

        raise UnexpectedNodeError(f"Multiply by constant vertex {self.name} has no scalar input.")

    def get_power(self):
        """Return the exponent of this ``Pow`` node.

        Returns:
            The first ``float_val`` entry of a Const found backwards from this node.

        Raises:
            UnexpectedNodeError: this node is not a ``Pow``, or no Const was found.

        """
        if self.op != "Pow":
            raise UnexpectedNodeError(f"Pow vertex not found near {self.name}")

        exponent_const = look_for_node(self._graph, self, [BwdChainNode(op="Const")])
        if exponent_const:
            return exponent_const._info.attr["value"].tensor.float_val[0]

        raise UnexpectedNodeError(f"Pow vertex {self.name} has no power input.")

    def get_activation_less_values(self):
        """Return the values this Less activation compares against.

        Returns:
            The array stored in a Const found backwards from this node, with its
            first axis squeezed out.

        Raises:
            UnsupportedActivationLayerError: no Const was found.

        """
        values_const = look_for_node(self._graph, self, [BwdChainNode(op="Const")])
        if not values_const:
            raise UnsupportedActivationLayerError(f"Unable to find values to compare for Less activation {self.name}.")

        values = tf.make_ndarray(values_const._info.attr["value"].tensor)
        return np.squeeze(values, axis=0)

    @property
    def in_valid_subgraph(self):
        """The stored ``in_valid_subgraph`` flag of this vertex."""
        return self._in_valid_subgraph

    @in_valid_subgraph.setter
    def in_valid_subgraph(self, val):
        # Log the new value before storing it.
        message = f"Marked vertex {self.name} in_valid_subgraph={val}"
        self._logger.debug(message)
        self._in_valid_subgraph = val

    def is_inv_pos_activation(self):
        """Return whether this division node computes a reciprocal (1 / x).

        A ``Reciprocal`` op always qualifies. Any other op from ``DIV_OPS``
        qualifies only when it has exactly one Const predecessor, that Const's
        first ``float_val`` equals 1 and it is the node's first input.
        """
        if self.op not in DIV_OPS:
            return False
        if self.op == "Reciprocal":
            return True

        const_operands = [pred for pred in self._graph.predecessors(self) if pred.op == "Const"]
        if len(const_operands) != 1:
            return False

        constant = const_operands[0]
        # Position of the Const among this node's inputs; 0 means it is the first input.
        position = next(i for i, name in enumerate(self.input) if constant.name in name)
        return constant._info.attr["value"].tensor.float_val[0] == 1 and position == 0


class TFGraph(NNGraph):
    def __init__(self, raw_graph_proto, values, logger=None):
        """Build the vertex/edge structure from a raw TensorFlow graph.

        Args:
            raw_graph_proto: raw graph object handed to the ``NNGraph`` base class.
            values: forwarded unchanged to the ``NNGraph`` base class.
            logger: optional logger; ``default_logger()`` is used when falsy.

        """
        super().__init__(raw_graph_proto, values)
        self._is_nchw = False
        self._logger = logger or default_logger()

        graph_def_nodes = self._raw_proto.as_graph_def(add_shapes=True).node

        # First pass: one vertex per graph-def node.
        for node_def in graph_def_nodes:
            vertex = TFGraphNode(node_def, self)
            self.add_node(vertex)
            self.add_vertex_by_name(vertex)
            if vertex.is_nchw():
                self._is_nchw = True

        # Second pass: connect each node to the producers named in its inputs.
        for node_def in graph_def_nodes:
            consumer_name = node_def.name
            for input_name in node_def.input:
                # Keep only the part before the first ":".
                producer_name = input_name.split(":")[0]
                if producer_name in self._vertices_by_name:
                    self.add_edge(self._vertices_by_name[producer_name], self._vertices_by_name[consumer_name])

    @property
    def is_nchw(self):
        """Whether any vertex reported ``is_nchw()`` while the graph was built."""
        return self._is_nchw

    def visualize(self, filename_prefix):
        """Write a filtered view of the graph to ``<filename_prefix>.dot`` and render it to ``.svg``.

        Only nodes whose op contains one of ``good_names`` or whose name ends with
        one of ``good_endings`` are kept, unless their name contains one of
        ``bad_names``. Edges are kept when both endpoints were kept.

        Args:
            filename_prefix: path prefix of the ``.dot`` and ``.svg`` output files.

        The rendering step is best effort: a warning is logged and nothing is
        written when graphviz is reported as unavailable, and a failure to launch
        ``dot`` is logged rather than raised.

        """
        # Local import: subprocess is only needed by this method.
        import subprocess

        if not SDKPaths().has_graphviz:
            self._logger.warning("Cannot visualize TF graph because graphviz in unavailable")
            return
        good_names = [
            "BatchNorm",
            "MaxPool",
            "Conv2D",
            "BiasAdd",
            "Relu",
            "Relu6",
            "Elu",
            "Sigmoid",
            "Identity",
            "Add",
            "MatMul",
            "Reshape",
            "Pad",
            "Concat",
            "DepthwiseConv2dNative",
            "ResizeBilinear",
            "ResizeNearestNeighbor",
            "AvgPool",
            "Split",
            "Slice",
            "StridedSlice",
        ]
        good_endings = [
            "batchnorm/add_1",
            "batchnorm/mul_1",
            "/convolution",
            "v1/add",
            "batchnorm/mul",
            "batchnorm/add",
        ]
        bad_names = ["gradient", "Adam", "/cond/", "paddings"]
        other_graph = nx.DiGraph()
        for node in self.nodes:
            has_good_name = any(good_name in node.op for good_name in good_names)
            has_good_ending = any(node.name.endswith(good_ending) for good_ending in good_endings)
            has_bad_name = any(bad_name in node.name for bad_name in bad_names)
            if (has_good_name or has_good_ending) and not has_bad_name:
                other_graph.add_node(node)
        for src_node, dst_node in self.edges:
            if src_node in other_graph.nodes and dst_node in other_graph.nodes:
                other_graph.add_edge(src_node, dst_node)

        dot_path = f"{filename_prefix}.dot"
        svg_path = f"{filename_prefix}.svg"
        nx.drawing.nx_agraph.write_dot(other_graph, dot_path)
        # Pass the arguments as a list (no shell) so that quotes, spaces or shell
        # metacharacters in the prefix cannot break or alter the command.
        try:
            subprocess.run(["dot", "-Tsvg", dot_path, "-o", svg_path], check=False)
        except OSError as err:
            self._logger.warning(f"Failed to run graphviz dot on {dot_path}: {err}")
