#!/usr/bin/env python
import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    CONCAT_DIM_TO_AXIS,
    DEFAULT_CONCAT_AXIS,
    ConcatAxis,
)
from hailo_sdk_client.model_translator.exceptions import (
    UnsupportedActivationLayerError,
    UnsupportedConstInputError,
    UnsupportedConvLayerError,
    UnsupportedEqualLayerError,
    UnsupportedFormatConversionLayerError,
    UnsupportedModelError,
    UnsupportedPoolingLayerError,
    UnsupportedReduceL2LayerError,
    UnsupportedReduceMaxLayerError,
    UnsupportedReduceMinLayerError,
    UnsupportedReduceSumLayerError,
    UnsupportedShuffleLayerError,
    UnsupportedSpaceToDepthLayerError,
    UnsupportedTileLayerError,
)
from hailo_sdk_client.model_translator.tflite_translator.tflite_graph import (
    CUSTOM_SIGN_OPS,
    PACK_OPS,
    SUPPORTED_FUSED_ACTIVATIONS,
)
from hailo_sdk_common.hailo_nn.hn_definitions import (
    ActivationType,
    DepthToSpaceType,
    FeatureMultiplierType,
    FormatConversionType,
    LayerType,
    NormalizationType,
    PaddingType,
    SpaceToDepthType,
)
from hailo_sdk_common.hailo_nn.hn_layers import (
    ActivationLayer,
    ArgmaxLayer,
    ConcatLayer,
    ConstInputLayer,
    Conv2DLayer,
    DenseLayer,
    DepthToSpaceLayer,
    EqualLayer,
    EWAddLayer,
    EWAddNLayer,
    EWDivLayer,
    EWMaxLayer,
    EWMinLayer,
    EWMultLayer,
    EWSubLayer,
    ExternalPadLayer,
    FeatureMultiplierLayer,
    FeatureShuffleLayer,
    FeatureSplitterLayer,
    FormatConversionLayer,
    LayerNormalizationLayer,
    NormalizationLayer,
    NullLayer,
    PoolingLayer,
    ReduceL2Layer,
    ReduceMaxLayer,
    ReduceMinLayer,
    ReduceSumLayer,
    ResizeLayer,
    SliceLayer,
    SoftmaxLayer,
    SpaceToDepthLayer,
)
from hailo_sdk_common.tools.models_translator_helper import FUSED_ACTIVATION_SUFFIX


def create_layer_from_vertex(layer_type, vertex, **kwargs):
    """Dispatch a vertex to the factory that builds the requested layer kind.

    Args:
        layer_type: Requested layer kind. It is compared with ``==`` against the
            ``LayerType`` members listed below (and the literal ``"pool"``).
        vertex: Graph vertex handed to the selected factory.
        **kwargs: Forwarded only to the space-to-depth factory.

    Returns:
        Whatever the selected factory returns. In this module most factories return
        ``(layer, consumed_vertices, fused_activation)``; the ew-add and concat
        factories append a fourth element holding an optional const-input layer.

    Raises:
        UnsupportedModelError: If ``layer_type`` matches none of the known kinds.
    """
    # Probed top to bottom with ``==``; the first match wins.
    factories = (
        (LayerType.conv, _create_convolutional_layer),
        (LayerType.base_activation, create_activation_layer),
        ("pool", _create_pooling_layer),
        (LayerType.dense, _create_dense_layer),
        (LayerType.base_ew_add, _create_ew_add_layer),
        (LayerType.base_ew_add_n, _create_ew_add_n_layer),
        (LayerType.base_ew_sub, _create_ew_sub_layer),
        (LayerType.base_slice, _create_slice_layer),
        (LayerType.concat, _create_concat_layer),
        (LayerType.resize, _create_resize_layer),
        (LayerType.depth_to_space, _create_depth_to_space_layer),
        (LayerType.format_conversion, _create_format_conversion_layer),
        (LayerType.ew_mult, _create_ew_mult_layer),
        (LayerType.ew_div, _create_ew_div_layer),
        (LayerType.ew_max, _create_ew_max_layer),
        (LayerType.ew_min, _create_ew_min_layer),
        (LayerType.normalization, _create_normalization_layer),
        # The only factory that receives the extra keyword arguments.
        (LayerType.space_to_depth, lambda v: _create_space_to_depth_layer(v, **kwargs)),
        (LayerType.external_pad, _create_external_pad_layer),
        (LayerType.reduce_sum, _create_reduce_sum_layer),
        (LayerType.feature_multiplier, _create_square_layer),
        (LayerType.reduce_l2, _create_reduce_l2_layer),
        (LayerType.l2_normalization, _create_l2_normalization_layer),
        (LayerType.softmax, _create_softmax_layer),
        (LayerType.argmax, _create_argmax_layer),
        (LayerType.feature_splitter, _create_feature_split_layer),
        (LayerType.null, _create_null_layer),
        (LayerType.feature_shuffle, _create_shuffle_layer),
        (LayerType.reduce_max, _create_reduce_max_layer),
        (LayerType.reduce_min, _create_reduce_min_layer),
        (LayerType.equal, _create_equal_layer),
        (LayerType.tile, _create_tile_layer),
    )
    for candidate, factory in factories:
        if layer_type == candidate:
            return factory(vertex)

    raise UnsupportedModelError(f"Unknown layer type {layer_type} for vertex {vertex.name} (op {vertex.op})")


def _create_convolutional_layer(vertex):
    """Create a ``Conv2DLayer`` from a CONV_2D, TRANSPOSE_CONV or DEPTHWISE_CONV_2D vertex.

    Returns:
        Tuple ``(layer, consumed_vertices, fused_activation)``. ``consumed_vertices`` is
        always empty; ``fused_activation`` is ``None`` unless the vertex reports an
        activation listed in ``SUPPORTED_FUSED_ACTIVATIONS``.

    Raises:
        UnsupportedModelError: A depthwise vertex has neither a kernel nor a dynamic
            kernel shape.
        UnsupportedConvLayerError: The vertex op is not one of the three handled ops.
    """
    conv_attrs, vertex_kernel, bias = vertex.get_conv_info()
    groups = 1
    kernel = None
    dynamic_kernel_shape = None

    if vertex.op in ("CONV_2D", "TRANSPOSE_CONV"):
        op = LayerType.base_conv if vertex.op == "CONV_2D" else LayerType.base_deconv
        # [f_out, k_h, k_w, f_in] (tflite repr) -> [k_h, k_w, f_in, f_out]
        kernel = np.transpose(vertex_kernel, [1, 2, 3, 0])
    elif vertex.op == "DEPTHWISE_CONV_2D":
        op = LayerType.base_dw
        if vertex_kernel is None:
            # No constant kernel: fall back to a kernel shape reported by the vertex.
            dynamic_kernel_shape = vertex.get_dynamic_kernel_shape()
            if not dynamic_kernel_shape:
                raise UnsupportedModelError(f"Cannot find kernel in vertex {vertex.name}")
        else:
            depth_multiplier = conv_attrs["depth_multiplier"]
            # [1, k_h, k_w, f_in * depth_multiplier] (tflite repr) -> [k_h, k_w, f_in, depth_multiplier]
            _, k_h, k_w, _ = vertex_kernel.shape
            f_in = vertex.input_shapes[0][-1]
            if depth_multiplier > 1:
                # Expressed as a grouped convolution with one group per input feature.
                op = LayerType.base_conv
                groups = f_in
                target_shape = [k_h, k_w, 1, f_in * depth_multiplier]
            else:
                target_shape = [k_h, k_w, f_in, depth_multiplier]
            kernel = np.reshape(vertex_kernel, target_shape)
    else:
        raise UnsupportedConvLayerError(f"Unexpected convolutional layer at {vertex.name}, op={vertex.op}")

    layer = Conv2DLayer.create(
        original_name=vertex.name,
        input_vertex_order=vertex.input,
        op=op,
        kernel=kernel,
        bias=bias,
        padding=conv_attrs["padding"],
        padding_vals=None,
        strides=conv_attrs["strides"],
        dilations=conv_attrs["dilations"],
        groups=groups,
        output_shapes=vertex.output_shapes,
        dynamic_kernel_shape=dynamic_kernel_shape,
    )

    reported_activation = conv_attrs["fused_activation"]
    fused_activation = reported_activation if reported_activation in SUPPORTED_FUSED_ACTIVATIONS else None
    return layer, [], fused_activation


def create_activation_layer(vertex, is_fused_activation=False, op=None):
    """Create an ``ActivationLayer`` for a vertex (or for an activation fused into it).

    Args:
        vertex: Graph vertex to translate.
        is_fused_activation: When true, the activation is selected by ``op`` rather than
            by ``vertex.op``; the layer name gets ``FUSED_ACTIVATION_SUFFIX`` appended and
            its only input is the vertex itself.
        op: Operation name of the fused activation; ignored unless
            ``is_fused_activation`` is true.

    Returns:
        Tuple ``(layer, consumed_vertices, None)``.

    Raises:
        UnsupportedActivationLayerError: No branch below recognised the operation.
    """
    consumed_vertices = []
    activation = None
    leaky_alpha = None
    prelu_slope = None
    activation_values = None
    activation_threshold = None
    swish_beta = None
    hardsigmoid_alpha = None
    hardsigmoid_beta = None
    clip_min = None
    clip_max = None
    delta = None

    op = op if is_fused_activation else vertex.op
    # NOTE: branch order matters - several ops ("MUL", "MINIMUM", "ABS") appear in more
    # than one branch, and the first branch whose vertex predicate holds wins.
    if op == "RELU":
        activation = ActivationType.relu
    elif op == "RELU6" or (op == "MINIMUM" and vertex.is_relu6_clip()):
        activation = ActivationType.relu6
    elif op == "TANH":
        activation = ActivationType.tanh
    elif op == "EXP":
        is_softplus_activation, consumed_vertices = vertex.is_softplus_activation()
        activation = ActivationType.softplus if is_softplus_activation else ActivationType.exp
    elif op == "PRELU":
        vertex_prelu_slope = vertex.get_prelu_slope()
        # A slope that is the same for every entry is emitted as leaky with a scalar alpha.
        if np.any(vertex_prelu_slope != vertex_prelu_slope[0]):
            activation = ActivationType.prelu
            prelu_slope = vertex_prelu_slope
        else:
            activation = ActivationType.leaky
            leaky_alpha = vertex_prelu_slope[0]
    elif op == "ELU":
        activation = ActivationType.elu
    elif op == "LEAKY_RELU":
        vertex_leaky_alpha = vertex.get_leaky_alpha()
        # A negative alpha is emitted as prelu with a single-element slope instead of leaky.
        if vertex_leaky_alpha < 0:
            activation = ActivationType.prelu
            prelu_slope = [vertex_leaky_alpha]
        else:
            activation = ActivationType.leaky
            leaky_alpha = vertex_leaky_alpha
    elif op == "HARD_SWISH":
        activation = ActivationType.hardswish
    elif op == "LOGISTIC":
        activation = ActivationType.sigmoid
    elif op == "SQRT":
        activation = ActivationType.sqrt
    elif op == "LOG":
        activation = ActivationType.log
    elif op == "MUL" and vertex.is_threshold_activation():
        activation = ActivationType.threshold
        activation_threshold, consumed_vertices = vertex.get_threshold_activation_values()
    elif op == "MUL" and vertex.is_mish_activation():
        consumed_vertices = vertex.get_mish_activation_vertices()
        activation = ActivationType.mish
    elif op == "MUL" and vertex.is_silu_activation():
        consumed_vertices = vertex.get_silu_activation_vertices()
        activation = ActivationType.silu
    elif op == "MUL" and vertex.is_swish_activation_second_mul():
        swish_beta, consumed_vertices = vertex.get_swish_beta()
        activation = ActivationType.swish
    elif op in ("LESS", "GREATER"):
        activation, activation_values, consumed_vertices = vertex.get_activation_less_or_greater_values()
    elif op == "MINIMUM" and vertex.is_hardsigmoid():
        activation = ActivationType.hardsigmoid
        hardsigmoid_alpha, hardsigmoid_beta, consumed_vertices = vertex.get_hardsigmoid_info()
    elif op in ("MINIMUM", "MAXIMUM"):
        # Any remaining MINIMUM/MAXIMUM is treated as a clip.
        activation = ActivationType.clip
        clip_min, clip_max, consumed_vertices = vertex.get_min_max_info()
    elif op in [*CUSTOM_SIGN_OPS, "ABS"] and vertex.is_biased_delta_activation():
        activation = ActivationType.biased_delta
        delta, consumed_vertices = vertex.get_biased_delta()
    elif op == "ABS":
        # A plain ABS is emitted as prelu with slope -1.
        activation = ActivationType.prelu
        prelu_slope = [-1.0]
    elif op == "DIV" and vertex.is_inv_pos_activation():
        activation = ActivationType.inv_pos
    elif op == "GELU":
        activation = ActivationType.gelu
    elif op == "RSQRT":
        activation = ActivationType.inv_sqrt

    if activation is None:
        raise UnsupportedActivationLayerError(f"Unexpected activation at {vertex.name}, op={op}")

    # A fused activation becomes its own layer, named after and fed by the vertex it was fused into.
    orig_name = vertex.name if not is_fused_activation else f"{vertex.name}{FUSED_ACTIVATION_SUFFIX}"
    vertex_input_order = vertex.input if not is_fused_activation else [vertex.name]
    layer = ActivationLayer.create(
        orig_name,
        vertex_input_order,
        activation,
        output_shapes=vertex.output_shapes,
        leaky_alpha=leaky_alpha,
        activation_threshold=activation_threshold,
        prelu_slope=prelu_slope,
        swish_beta=swish_beta,
        delta_bias=delta,
        activation_values=activation_values,
        hardsigmoid_alpha=hardsigmoid_alpha,
        hardsigmoid_beta=hardsigmoid_beta,
        clip_min=clip_min,
        clip_max=clip_max,
    )

    return layer, consumed_vertices, None


def _create_ew_const_input_layer(vertex):
    """Create a ``ConstInputLayer`` for the constant operand of an element-wise vertex.

    Returns:
        The const-input layer, or ``None`` when the vertex reports no constant values.

    Raises:
        UnsupportedConstInputError: The constant is not rank 4 with a leading dimension of 1.
    """
    input_name, input_values = vertex.get_ew_const_input_info()
    if input_values is None:
        return None

    shape = input_values.shape
    if not (len(shape) == 4 and shape[0] == 1):
        raise UnsupportedConstInputError(
            "TFLite currently supports const input tensor with rank 4 and batch size 1. "
            f"Got shape {shape} at {vertex.name}.",
        )

    # Drop the leading (batch) dimension; the layer's output shape uses -1 in its place.
    squeezed = np.squeeze(input_values, axis=0)
    return ConstInputLayer.create(input_name, [[-1, *squeezed.shape]], squeezed)


def _create_ew_add_layer(vertex):
    """Create an ``EWAddLayer``, plus a const-input layer when one operand is constant.

    Returns:
        Tuple ``(layer, [], fused_activation, const_layer)``; ``const_layer`` is ``None``
        unless the vertex is an add with a constant input.
    """
    activation = vertex.get_ew_op_fused_activation()
    constant_operand = _create_ew_const_input_layer(vertex) if vertex.is_ew_add_with_const_input() else None
    add_layer = EWAddLayer.create(vertex.name, vertex.input, vertex.output_shapes)
    return add_layer, [], activation, constant_operand


def _create_ew_add_n_layer(vertex):
    """Create an ``EWAddNLayer``; returns ``(layer, [], fused_activation)``."""
    activation = vertex.get_ew_op_fused_activation()
    return EWAddNLayer.create(vertex.name, vertex.input, vertex.output_shapes), [], activation


def _create_ew_sub_layer(vertex):
    """Create an ``EWSubLayer``; returns ``(layer, [], fused_activation)``."""
    activation = vertex.get_ew_op_fused_activation()
    return EWSubLayer.create(vertex.name, vertex.input, output_shapes=vertex.output_shapes), [], activation


def _create_ew_mult_layer(vertex):
    """Create an ``EWMultLayer``; returns ``(layer, [], fused_activation)``."""
    activation = vertex.get_ew_op_fused_activation()
    return EWMultLayer.create(vertex.name, vertex.input, vertex.output_shapes), [], activation


def _create_ew_div_layer(vertex):
    """Create an ``EWDivLayer``; returns ``(layer, [], fused_activation)``."""
    activation = vertex.get_ew_op_fused_activation()
    return EWDivLayer.create(vertex.name, vertex.input, vertex.output_shapes), [], activation


def _create_ew_max_layer(vertex):
    """Create an ``EWMaxLayer``; returns ``(layer, [], fused_activation)``."""
    activation = vertex.get_ew_op_fused_activation()
    return EWMaxLayer.create(vertex.name, vertex.input, vertex.output_shapes), [], activation


def _create_ew_min_layer(vertex):
    """Create an ``EWMinLayer``; returns ``(layer, [], fused_activation)``."""
    activation = vertex.get_ew_op_fused_activation()
    return EWMinLayer.create(vertex.name, vertex.input, vertex.output_shapes), [], activation


def _create_square_layer(vertex):
    """Create a ``FeatureMultiplierLayer`` of type ``square``; returns ``(layer, [], None)``."""
    create_kwargs = {
        "feature_multiplier_type": FeatureMultiplierType.square,
        "output_shapes": vertex.output_shapes,
    }
    square_layer = FeatureMultiplierLayer.create(vertex.name, vertex.input, **create_kwargs)
    return square_layer, [], None


def _create_dense_layer(vertex):
    """Create a ``DenseLayer`` from a fully-connected vertex.

    Returns:
        Tuple ``(layer, consumed_vertices, fused_activation)``; ``consumed_vertices`` is
        always empty and ``fused_activation`` is taken from the vertex's dense attributes.

    Raises:
        UnsupportedModelError: The vertex has no constant kernel (data-driven
            fully-connected layer).
    """
    consumed_vertices = []
    dense_attrs, vertex_kernel, bias = vertex.get_dense_info()

    # The kernel (not the vertex) is what may be missing. Previously the check tested
    # ``vertex`` and the error was built but never raised, so a missing kernel surfaced
    # as an unrelated failure inside ``np.transpose``.
    if vertex_kernel is None:
        raise UnsupportedModelError("Unsupported data driven fully-connected layer")

    # [f_out, f_in] (tflite repr) -> [f_in, f_out]
    kernel = np.transpose(vertex_kernel, [1, 0])

    layer = DenseLayer.create(vertex.name, vertex.input, bias, kernel, output_shapes=vertex.output_shapes)

    return layer, consumed_vertices, dense_attrs["fused_activation"]


def _create_l2_normalization_layer(vertex):
    """Create an L2 normalization expressed as a ``LayerNormalizationLayer`` in RMS-norm mode.

    Returns:
        Tuple ``(layer, [], activation)`` where ``activation`` comes from
        ``vertex.get_l2_normalization_activation()``.
    """
    activation = vertex.get_l2_normalization_activation()
    # Size of the last dimension of the first input shape.
    last_dim = vertex.get_input_shapes()[0][-1]
    norm_params = {
        "axes": [3],
        "epsilon": np.array(0),
        "scale": np.sqrt(1 / last_dim),
        "B": np.array(0),
        "rms_norm": True,
    }
    norm_layer = LayerNormalizationLayer.create(vertex.name, vertex.input, norm_params, rms_norm=True)
    return norm_layer, [], activation


def _create_pooling_layer(vertex):
    """Create a ``PoolingLayer`` from a pooling, MEAN or REDUCE_MAX vertex.

    AVERAGE_POOL_2D and MEAN become average pooling; every other accepted op becomes
    max pooling.

    Returns:
        Tuple ``(layer, [], fused_activation)``; ``fused_activation`` is ``None`` unless
        the op is AVERAGE_POOL_2D or MAX_POOL_2D.

    Raises:
        UnsupportedPoolingLayerError: The vertex op is none of the handled ops.
    """
    op = LayerType.avgpool if vertex.op in ("AVERAGE_POOL_2D", "MEAN") else LayerType.maxpool

    # Defaults shared by the reduce-style branches; the explicit pooling branch overrides them.
    pool_attrs = None
    use_input_shape_as_kernel = False
    padding = PaddingType.valid
    strides = None
    kernel_shape = None

    if vertex.op in ("AVERAGE_POOL_2D", "MAX_POOL_2D"):
        pool_attrs = vertex.get_pooling_info()
        padding = pool_attrs["padding"]
        strides = pool_attrs["strides"]
        kernel_shape = pool_attrs["kernel_shape"]
    elif vertex.op == "MEAN" and vertex.is_avgpool_reduce_mean():
        kernel_shape, strides = vertex.get_avgpool_reduce_mean_info()
    elif vertex.op in ("MEAN", "REDUCE_MAX"):
        # No explicit kernel: the layer is told to take its kernel from the input shape.
        use_input_shape_as_kernel = True
        if vertex.op == "MEAN":
            vertex.validate_reduce_mean_info()
    else:
        raise UnsupportedPoolingLayerError(f"Unexpected pooling layer type {vertex.op} in vertex {vertex.name}")

    layer = PoolingLayer.create(
        vertex.name,
        vertex.input,
        op,
        kernel_shape,
        strides,
        padding,
        padding_vals=None,
        should_set_kernel_to_input_shape=use_input_shape_as_kernel,
        output_shapes=vertex.output_shapes,
        count_include_pad=False,
    )

    fused_activation = None if pool_attrs is None else pool_attrs["fused_activation"]
    return layer, [], fused_activation


def _create_softmax_layer(vertex):
    """Create a ``SoftmaxLayer``; returns ``(layer, [], None)``."""
    return SoftmaxLayer.create(vertex.name, vertex.input, output_shapes=vertex.output_shapes), [], None


def _create_argmax_layer(vertex):
    """Create an ``ArgmaxLayer``; returns ``(layer, [], None)``."""
    return ArgmaxLayer.create(vertex.name, vertex.input, output_shapes=vertex.output_shapes), [], None


def _create_normalization_layer(vertex):
    """Create a ``NormalizationLayer`` from a normalization-like vertex.

    Returns:
        Tuple ``(layer, consumed_vertices, fused_activation)``.
    """
    if vertex.is_mul_by_2_ew_add():
        # Fixed mean 0 / std 0.5 for this pattern.
        # NOTE(review): presumably (x - 0) / 0.5 == 2 * x - confirm against NormalizationLayer.
        attrs = {"fused_activation": vertex.get_ew_op_fused_activation()}
        mean, std = [0.0], [0.5]
        consumed_vertices = []
    else:
        attrs, mean, std, consumed_vertices = vertex.get_normalization_info()

    # if std = 1 or mean = 0, normalization is actually standalone mul, sub, or add
    std_is_one = all(value == 1 for value in std)
    if std_is_one or all(value == 0 for value in mean):
        norm_type = NormalizationType.mul_and_add
    else:
        norm_type = NormalizationType.normalization

    activation = vertex.get_normalization_activation(consumed_vertices)
    norm_layer = NormalizationLayer.create(
        vertex.name,
        vertex.input,
        mean,
        std,
        normalization_type=norm_type,
        activation=activation,
    )
    return norm_layer, consumed_vertices, attrs["fused_activation"]


def _create_external_pad_layer(vertex):
    """Create an ``ExternalPadLayer``; returns ``(layer, consumed_vertices, None)``."""
    padding_vals, consumed = vertex.get_padding_info()
    pad_layer = ExternalPadLayer.create(
        vertex.name, vertex.input, padding_vals=padding_vals, output_shapes=vertex.output_shapes
    )
    return pad_layer, consumed, None


def _create_tile_layer(vertex):
    """Create a tile along one axis as a ``ConcatLayer`` fed by the same input repeatedly.

    Returns:
        Tuple ``(layer, consumed_vertices, None)``.

    Raises:
        UnsupportedTileLayerError: The vertex has no multipliers, tiles more than one
            axis, or tiles no axis at all (every multiplier equals 1).
    """
    multipliers, consumed_vertices = vertex.get_tile_multipliers_info()

    if multipliers is None:
        raise UnsupportedTileLayerError("Tile layer with no multipliers is not supported")

    # Keep only the axes that are actually repeated.
    tiled_axes = [(idx, repeats) for idx, repeats in enumerate(multipliers) if repeats != 1]
    if len(tiled_axes) > 1:
        raise UnsupportedTileLayerError("Tile layer with multi-axes-tiling is not supported")
    if not tiled_axes:
        # Without this guard, indexing the empty list below raised a bare IndexError.
        raise UnsupportedTileLayerError(
            f"Tile layer at {vertex.name} does not tile any axis (all multipliers are 1) and is not supported",
        )

    tile_dim, repeats = tiled_axes[0]
    axis = CONCAT_DIM_TO_AXIS.get(tile_dim)
    layer = ConcatLayer.create(
        vertex.name,
        [vertex.input[0]] * repeats,
        output_shapes=vertex.output_shapes,
        axis=axis,
    )
    return layer, consumed_vertices, None


def _create_concat_layer(vertex):
    """Create a ``ConcatLayer``, plus a const-input layer when the vertex has a constant input.

    Returns:
        Tuple ``(layer, [], fused_activation, const_layer)``; ``const_layer`` is ``None``
        when the vertex reports no constant input value.
    """
    input_name, input_val = vertex.get_concat_const_input_info()
    const_layer = None
    if input_val is not None:
        const_layer = ConstInputLayer.create(input_name, [[-1, *input_val.shape]], input_val)

    concat_attrs = vertex.get_concat_info()
    output_rank = len(vertex.output_shapes[0])
    vertex_axis = concat_attrs["axis"]

    # Only rank-4 outputs can select a spatial axis; everything else keeps the default axis.
    axis = DEFAULT_CONCAT_AXIS
    if output_rank == 4:
        if vertex_axis in (2, -2):
            axis = ConcatAxis.spatial_w
        elif vertex_axis in (1, -3):
            axis = ConcatAxis.spatial_h

    concat_layer = ConcatLayer.create(vertex.name, vertex.input, output_shapes=vertex.output_shapes, axis=axis)
    return concat_layer, [], concat_attrs["fused_activation"], const_layer


def _create_feature_split_layer(vertex):
    """Create a ``FeatureSplitterLayer``, or a ``SliceLayer`` when the split has one output.

    Returns:
        Tuple ``(layer, [], None)``.
    """
    features_split_dims, split_indices, output_shapes = vertex.get_feature_split_info()

    if len(output_shapes) != 1:
        splitter = FeatureSplitterLayer.create(
            vertex.name,
            vertex.input,
            features_split_dims,
            output_shapes=vertex.output_shapes,
        )
        return splitter, [], None

    # Support for feature splitter with one output, converted automatically to slice layer
    only_shape = output_shapes[0]
    features_start = split_indices[0]
    features_end = features_start + only_shape[3]
    slice_layer = SliceLayer.create(
        vertex.name,
        vertex.input,
        [0, only_shape[1]],
        [0, only_shape[2]],
        [features_start, features_end],
        output_shapes=output_shapes,
    )
    return slice_layer, [], None


def _create_resize_layer(vertex):
    """Create a ``ResizeLayer`` from a resize vertex or from a pack-based resize pattern.

    Returns:
        Tuple ``(layer, consumed_vertices, None)``; ``consumed_vertices`` is empty unless
        the vertex op is one of ``PACK_OPS``.
    """
    consumed_vertices = []
    if vertex.op in PACK_OPS:
        resize_attrs, consumed_vertices = vertex.get_resize_pack_info()
        w_sizes, h_sizes, d_sizes, resize_method, pixels_mode, output_shapes = (
            resize_attrs[key]
            for key in ("w_sizes", "h_sizes", "d_sizes", "resize_method", "pixels_mode", "output_shapes")
        )
    else:
        resize_method, resize_sizes, pixels_mode = vertex.get_resize_info()
        h_sizes = w_sizes = d_sizes = None
        if resize_sizes:
            # The last two entries are used as height and width sizes.
            h_sizes = resize_sizes[-2]
            w_sizes = resize_sizes[-1]
            if len(resize_sizes) == 5:
                d_sizes = resize_sizes[-3]
        output_shapes = vertex.output_shapes

    resize_layer = ResizeLayer.create(
        vertex.name,
        vertex.input,
        resize_method=resize_method,
        output_shapes=output_shapes,
        h_sizes=h_sizes,
        w_sizes=w_sizes,
        d_sizes=d_sizes,
        resize_bilinear_pixels_mode=pixels_mode,
    )
    return resize_layer, consumed_vertices, None


def _create_reduce_max_layer(vertex):
    """Create a ``ReduceMaxLayer``; returns ``(layer, [], None)``.

    Raises:
        UnsupportedReduceMaxLayerError: ``vertex.is_valid_reduce_max_min()`` is false.
    """
    if vertex.is_valid_reduce_max_min():
        reduce_layer = ReduceMaxLayer.create(vertex.name, vertex.input, output_shapes=vertex.output_shapes)
        return reduce_layer, [], None

    message = (
        f"Failed to create reduce max layer at vertex {vertex.name}. Reduce "
        f"max is only supported on the features axis, and with keepdim=True"
    )
    raise UnsupportedReduceMaxLayerError(message)


def _create_reduce_min_layer(vertex):
    """Create a ``ReduceMinLayer``; returns ``(layer, [], None)``.

    Raises:
        UnsupportedReduceMinLayerError: ``vertex.is_valid_reduce_max_min()`` is false.
    """
    if vertex.is_valid_reduce_max_min():
        reduce_layer = ReduceMinLayer.create(vertex.name, vertex.input, output_shapes=vertex.output_shapes)
        return reduce_layer, [], None

    message = (
        f"Failed to create reduce min layer at vertex {vertex.name}. Reduce "
        f"min is only supported on the features axis, and with keepdim=True"
    )
    raise UnsupportedReduceMinLayerError(message)


def _create_shuffle_layer(vertex):
    """Create a ``FeatureShuffleLayer`` from a reshape -> transpose -> reshape pattern.

    Returns:
        Tuple ``(layer, consumed_vertices, None)``.

    Raises:
        UnsupportedShuffleLayerError: The first reshape is not rank 5 or the transpose
            permutation is not ``[0, 1, 2, 4, 3]``.
    """
    second_reshape, perm, consumed_vertices = vertex.get_shuffle_reshape_transpose_info()
    first_reshape_shape = vertex.output_shapes[0]

    # Reshape to rank 5 and transpose perm=[0,1,2,4,3] -> FeatureShuffle, reshape[3] is num of groups.
    matches_pattern = len(first_reshape_shape) == 5 and perm == [0, 1, 2, 4, 3]
    if not matches_pattern:
        raise UnsupportedShuffleLayerError(f"Unable to create shuffle layer at {vertex.name}")

    shuffle_layer = FeatureShuffleLayer.create(
        vertex.name,
        vertex.input,
        groups=first_reshape_shape[3],
        output_shapes=second_reshape.output_shapes,
    )
    return shuffle_layer, consumed_vertices, None


def _create_depth_to_space_layer(vertex):
    """Create a DCR ``DepthToSpaceLayer``; returns ``(layer, consumed_vertices, None)``."""
    block_size = vertex.get_d2s_block_size()
    consumed, output_shapes = vertex.get_depth_to_space_info()
    create_kwargs = {
        "block_size": block_size,
        "output_shapes": output_shapes,
        # TF only supports DCR
        "depth_to_space_type": DepthToSpaceType.dcr,
    }
    d2s_layer = DepthToSpaceLayer.create(vertex.name, vertex.input, **create_kwargs)
    return d2s_layer, consumed, None


def _create_reduce_sum_layer(vertex):
    """Create a ``ReduceSumLayer``; returns ``(layer, [], None)``.

    Raises:
        UnsupportedReduceSumLayerError: The vertex reports its reduce-sum info as invalid.
    """
    is_valid, axes = vertex.get_reduce_sum_info()
    if is_valid:
        reduce_layer = ReduceSumLayer.create(vertex.name, vertex.input, axes, output_shapes=vertex.output_shapes)
        return reduce_layer, [], None

    message = (
        f"Failed to create reduce sum layer at vertex {vertex.name}. Reduce "
        f"sum is only supported on the height, width or features axis, and "
        f"with keepdim=True"
    )
    raise UnsupportedReduceSumLayerError(message)


def _create_null_layer(vertex):
    """Create a ``NullLayer``; returns ``(layer, consumed_vertices, None)``."""
    consumed = vertex.get_null_vertices()
    return NullLayer.create(vertex.name, vertex.input, vertex.output_shapes), consumed, None


def _create_reduce_l2_layer(vertex):
    """Create a ``ReduceL2Layer``; returns ``(layer, consumed_vertices, None)``.

    Raises:
        UnsupportedReduceL2LayerError: The vertex reports its reduce-L2 info as invalid.
    """
    is_valid, axes, consumed = vertex.get_reduce_l2_info()
    if is_valid:
        reduce_layer = ReduceL2Layer.create(vertex.name, vertex.input, axes, output_shapes=vertex.output_shapes)
        return reduce_layer, consumed, None

    message = (
        f"Failed to create reduce L2 layer at vertex {vertex.name}. Reduce L2 "
        f"is only supported on the height, width or features axis, and with "
        f"keepdim=True"
    )
    raise UnsupportedReduceL2LayerError(message)


def _create_slice_layer(vertex):
    """Create a ``SliceLayer`` from the vertex's slice values; returns ``(layer, [], None)``."""
    # get_slices_values() yields the height, width and features slices, in that order.
    slices = vertex.get_slices_values()
    height_slice, width_slice, features_slice = slices
    slice_layer = SliceLayer.create(
        vertex.name, vertex.input, height_slice, width_slice, features_slice, output_shapes=vertex.output_shapes
    )
    return slice_layer, [], None


def _create_space_to_depth_layer(vertex, space_to_depth_type=SpaceToDepthType.classic_dcr):
    """Create a ``SpaceToDepthLayer`` of the requested type.

    Args:
        vertex: Graph vertex to translate.
        space_to_depth_type: Variant to create. For ``SpaceToDepthType.focus`` the block
            size is fixed to 2, the last dimension of every output shape is multiplied by
            ``block_size ** 2`` and the vertex's consumed vertices are collected; otherwise
            the block size is read from the vertex.

    Returns:
        Tuple ``(layer, consumed_vertices, None)``.

    Raises:
        UnsupportedSpaceToDepthLayerError: Focus type with a block size other than 2.
    """
    is_focus = space_to_depth_type == SpaceToDepthType.focus
    consumed_vertices = []
    output_shapes = vertex.output_shapes

    if is_focus:
        block_size = 2
        output_shapes = [shape[:-1] + [shape[-1] * block_size * block_size] for shape in output_shapes]
        consumed_vertices = vertex.get_space_to_depth_consumed_vertices()
    else:
        block_size = vertex.get_space_to_depth_block_size()

    # NOTE(review): with the focus block size hard-coded to 2 above, this check cannot
    # fire as written; kept to preserve the original contract.
    if is_focus and block_size != 2:
        raise UnsupportedSpaceToDepthLayerError(
            f"Space to depth layers focus type are only supported with block size "
            f"of 2, while in node {vertex.name} the block size is {block_size}",
        )

    s2d_layer = SpaceToDepthLayer.create(
        vertex.name,
        vertex.input,
        [block_size, block_size],
        output_shapes=output_shapes,
        space_to_depth_type=space_to_depth_type,
    )
    return s2d_layer, consumed_vertices, None


def _create_format_conversion_layer(vertex):
    """Create a ``FormatConversionLayer`` for a recognised reshape/transpose vertex.

    The vertex predicates are probed in a fixed order and the first one that holds
    selects the conversion type.

    Returns:
        Tuple ``(layer, [], None)``.

    Raises:
        UnsupportedFormatConversionLayerError: None of the predicates holds.
    """
    shapes = vertex.output_shapes[0]

    if vertex.is_flat_to_frames_reshape():
        conversion_type = FormatConversionType.flat_to_frames
    elif vertex.is_features_reshape():
        conversion_type = FormatConversionType.features_to_width_features
    elif vertex.is_width_features_transpose():
        conversion_type = FormatConversionType.transpose_width_features
    elif vertex.is_height_width_transpose():
        conversion_type = FormatConversionType.transpose_height_width
    else:
        raise UnsupportedFormatConversionLayerError(f"Unable to create format conversion layer at {vertex.name}")

    conversion_layer = FormatConversionLayer.create(vertex.name, vertex.input, conversion_type, shapes)
    return conversion_layer, [], None


def _create_equal_layer(vertex):
    """Create an ``EqualLayer``; returns ``(layer, [], None)``.

    Raises:
        UnsupportedEqualLayerError: The vertex has a single input and no constant values
            to compare against.
    """
    values, _ = vertex.get_normalization_node_values()
    if values is None:
        # With one graph input and no constant there is nothing to compare with.
        if len(vertex.input) == 1:
            raise UnsupportedEqualLayerError(f"Unable to find equal inputs for {vertex.name}")
    equal_layer = EqualLayer.create(vertex.name, vertex.input, values, vertex.output_shapes)
    return equal_layer, [], None
