#!/usr/bin/env python
import numpy as np

from hailo_sdk_client.model_translator.exceptions import (
    RecordableParserError,
    UnsupportedEWLayerError,
    UnsupportedPaddingError,
    UnsupportedReduceL2LayerError,
    UnsupportedReduceMaxLayerError,
    UnsupportedReduceSumLayerError,
    UnsupportedResizeLayerError,
    UnsupportedSpaceToDepthLayerError,
    UnsupportedSquareLayerError,
    UnsupportedStridesError,
)
from hailo_sdk_client.model_translator.tf_translator.tf_graph import (
    DIV_OPS,
    EINSUM_OPS,
    OTHER_OPS,
    SPLIT_OPS,
    SQUARE_OPS,
    VAR_OPS,
)
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import (
    ActivationType,
    FeatureMultiplierType,
    FormatConversionType,
    LayerType,
    NormalizationType,
    PaddingType,
    ResizeMethod,
    SpaceToDepthType,
)
from hailo_sdk_common.hailo_nn.hn_layers import (
    ActivationLayer,
    ArgmaxLayer,
    BatchNormLayer,
    BiasAddLayer,
    ConcatLayer,
    Conv2DLayer,
    DenseLayer,
    DepthToSpaceLayer,
    EWAddLayer,
    EWAddNLayer,
    EWDivLayer,
    EWMultLayer,
    EWSubLayer,
    ExternalPadLayer,
    FeatureMultiplierLayer,
    FeatureShuffleLayer,
    FeatureSplitterLayer,
    FormatConversionLayer,
    InputLayer,
    L2NormalizationLayer,
    MatmulLayer,
    NormalizationLayer,
    PoolingLayer,
    ReduceL2Layer,
    ReduceMaxLayer,
    ReduceSumLayer,
    ResizeLayer,
    SliceLayer,
    SoftmaxLayer,
    SpaceToDepthLayer,
)
from hailo_sdk_common.logger.logger import default_logger


def create_layer_from_vertex(layer_type, vertex, errors_dict, **kwargs):
    """Create an HN layer of type ``layer_type`` from a TF graph ``vertex``.

    Dispatches to the matching ``_create_*_layer`` helper. ``kwargs`` are forwarded only to the
    helpers whose call below passes ``**kwargs``.

    NOTE(review): ``_create_pooling_layer``, ``_create_ew_add_layer``, ``_create_ew_add_n_layer`` and
    ``_create_feature_splitter_layer`` are called with ``**kwargs`` here but accept no extra
    parameters, so a non-empty ``kwargs`` would raise ``TypeError`` for them -- confirm callers never
    pass any for those layer types.

    Args:
        layer_type: Requested layer type (a ``LayerType`` member, or the literal string ``"pool"``
            for pooling).
        vertex: TF graph vertex to translate.
        errors_dict: Receives ``errors_dict[vertex] = (error,)`` when creation raises a
            ``RecordableParserError``.
        **kwargs: Helper-specific options (see the individual ``_create_*_layer`` helpers).

    Returns:
        The created layer, or ``None`` when a ``RecordableParserError`` was recorded in
        ``errors_dict``.

    Raises:
        UnsupportedModelError: If ``layer_type`` is not handled, including a ``feature_multiplier``
            request whose ``vertex.op`` is not in ``SQUARE_OPS``.
    """
    try:
        if layer_type == LayerType.input_layer:
            return _create_input_layer(vertex, **kwargs)
        elif layer_type == LayerType.conv:
            return _create_convolutional_layer(vertex, **kwargs)
        elif layer_type == "pool":
            return _create_pooling_layer(vertex, **kwargs)
        elif layer_type == LayerType.dense:
            return _create_dense_layer(vertex)
        elif layer_type == LayerType.batch_norm:
            return _create_batch_norm_layer(vertex)
        elif layer_type == LayerType.ew_add:
            return _create_ew_add_layer(vertex, **kwargs)
        elif layer_type == LayerType.base_ew_add_n:
            return _create_ew_add_n_layer(vertex, **kwargs)
        elif layer_type == LayerType.base_ew_sub:
            return _create_ew_sub_layer(vertex)
        elif layer_type == LayerType.slice:
            return _create_slice_layer(vertex)
        elif layer_type == LayerType.concat:
            return _create_concat_layer(vertex)
        elif layer_type == LayerType.resize:
            return _create_resize_layer(vertex)
        elif layer_type == LayerType.feature_shuffle:
            return _create_feature_shuffle_layer(vertex)
        elif layer_type == LayerType.depth_to_space:
            return _create_depth_to_space_layer(vertex, **kwargs)
        elif layer_type == LayerType.argmax:
            return _create_argmax_layer(vertex)
        elif layer_type == LayerType.softmax:
            return _create_softmax_layer(vertex)
        elif layer_type == LayerType.feature_splitter:
            return _create_feature_splitter_layer(vertex, **kwargs)
        elif layer_type == LayerType.bias_add:
            return _create_bias_add_layer(vertex)
        elif layer_type == LayerType.format_conversion:
            return _create_format_conversion_layer(vertex, **kwargs)
        elif layer_type == LayerType.ew_mult:
            return _create_ew_mult_layer(vertex)
        elif layer_type == LayerType.ew_div:
            return _create_ew_div_layer(vertex)
        elif layer_type == LayerType.normalization:
            return _create_normalization_layer(vertex)
        elif layer_type == LayerType.space_to_depth:
            return _create_space_to_depth_layer(vertex, **kwargs)
        elif layer_type == LayerType.external_pad:
            return _create_external_pad_layer(vertex)
        elif layer_type == LayerType.reduce_sum:
            return _create_reduce_sum_layer(vertex)
        elif layer_type == LayerType.matmul:
            return _create_matmul_layer(vertex)
        # Only square-like ops are translated to a feature multiplier; other ops fall through to the
        # "not supported" error below.
        elif layer_type == LayerType.feature_multiplier and vertex.op in SQUARE_OPS:
            return _create_square_layer(vertex)
        elif layer_type == LayerType.reduce_l2:
            return _create_reduce_l2_layer(vertex)
        elif layer_type == LayerType.l2_normalization:
            return _create_l2_normalization_layer(vertex)
    except RecordableParserError as e:
        # Record the failure against the vertex instead of propagating it; other exception types
        # are not caught here.
        default_logger().debug(
            f"An error was encountered while trying to create {layer_type} layer from vertex {vertex}",
        )
        errors_dict[vertex] = (e,)
        return None
    # Activation layers errors are not recordable, because if no layer is created, we cannot continue with the
    # state machine, as it depends on the current layer to maintain its correctness.
    if layer_type == LayerType.activation:
        return _create_activation_layer(vertex, **kwargs)
    raise UnsupportedModelError(f"Layer type {layer_type} is not supported")


def _create_input_layer(vertex, external_output_shapes=None):
    """Build an ``InputLayer`` named after ``vertex``, optionally with externally supplied output shapes."""
    layer_name = vertex.name
    return InputLayer.create(layer_name, external_output_shapes)


def _create_convolutional_layer(vertex, is_bias=False, is_dilation_s2b=False):
    """Create a ``Conv2DLayer`` from ``vertex``.

    With neither flag set, the conv op fields are populated from ``vertex`` via ``set_op``. With
    ``is_bias`` the vertex is attached as a pre-layer bias instead, and with ``is_dilation_s2b`` the
    layer is marked as an s2b-dilated conv with dilations/padding read from ``vertex``; in both
    flagged cases ``set_op`` is not called here (presumably the caller sets the op later -- confirm).

    Args:
        vertex: TF graph vertex to translate.
        is_bias: Treat ``vertex`` as a bias that precedes the conv.
        is_dilation_s2b: Treat ``vertex`` as the space-to-batch part of a dilated conv.

    Returns:
        The ``Conv2DLayer``.

    Raises:
        UnsupportedStridesError: For a ``Conv2DBackpropInput`` (deconv) whose strides at indices 1
            and 2 differ.
        UnsupportedPaddingError: For a deconv with VALID padding.
    """
    layer = Conv2DLayer()
    _set_basic_layer_params(layer, vertex)

    should_set_op = (not is_bias) and (not is_dilation_s2b)
    if should_set_op:
        set_op(LayerType.conv, layer, vertex)
    if is_bias:
        _set_pre_layer_bias(layer, vertex)
    elif is_dilation_s2b:
        # this case handles dilated conv layers that were defined in some high level API (e.g. slim)
        layer.is_dilated_s2b = True
        layer.dilations = vertex.get_dilations(is_dilation_s2b)
        layer.padding = vertex.get_dilated_s2b_padding(layer.dilations)

    # A depthwise kernel with last dim 1 stays a plain depthwise op.
    if vertex.op == "DepthwiseConv2dNative" and layer.kernel_shape[-1] == 1:
        layer.op = LayerType.base_dw
    elif vertex.op == "Conv2DBackpropInput":
        if layer.strides[1] != layer.strides[2]:
            raise UnsupportedStridesError(
                f"Deconv layer (translated from {vertex.name}) only supported with equal strides (deconv rate)",
            )
        if layer.padding == PaddingType.valid:
            raise UnsupportedPaddingError(
                f"Deconv layer (translated from {vertex.name}) only supported with padding type SAME_TENSORFLOW",
            )
        layer.op = LayerType.base_deconv
        # NOTE(review): kernels of height 2 keep their current padding; the reason is not visible here.
        if layer.kernel_height != 2:
            layer.padding = PaddingType.deconv

        # In order to stay consistent with hw and emulator implementation, we keep the deconv kernel in shape of
        # [w, h, in_c, out_c] instead of TF representation which is [w, h, out_c, in_c].
        layer.kernel_shape[-2:] = reversed(layer.kernel_shape[-2:])
        layer.kernel = np.transpose(layer.kernel, [0, 1, 3, 2])

    return layer


def _create_pooling_layer(vertex):
    """Create a ``PoolingLayer`` (or, for a non-global ``Max``, a reduce-max layer) from ``vertex``.

    Supported ops:
        * ``AvgPool`` / ``MaxPool``: kernel, strides and padding are read from the op. For a 1D
          ``MaxPool`` the entries at indices 1 and 2 of the output shapes, kernel and strides are
          swapped.
        * ``Mean`` (validated as a global average pool) and a ``Max`` that is a global max pool:
          VALID padding, the kernel is set to the input shape, and each output shape becomes
          ``[shape[0], 1, 1, shape[-1]]``.
        * Any other ``Max``: delegated to ``_create_reduce_max_layer``.

    Raises:
        UnsupportedModelError: If ``vertex.op`` is none of the ops above.
    """
    op = LayerType.avgpool if vertex.op in ["AvgPool", "Mean"] else LayerType.maxpool

    if vertex.op in ["AvgPool", "MaxPool"]:
        output_shapes = vertex.get_output_shapes()
        kernel_shape = vertex.get_pooling_ksize()
        strides = vertex.get_strides()
        padding = vertex.get_padding_from_op()
        should_set_kernel_to_input_shape = False

    elif vertex.op == "Mean" or (vertex.op == "Max" and vertex.is_global_max_pool()):
        if vertex.op == "Mean":
            vertex.validate_global_avg_pool()
        kernel_shape = None
        strides = None
        padding = PaddingType.valid
        should_set_kernel_to_input_shape = True
        output_shapes = [[shape[0], 1, 1, shape[-1]] for shape in vertex.get_output_shapes()]

    elif vertex.op == "Max":
        return _create_reduce_max_layer(vertex)

    else:
        # Fail explicitly; otherwise the locals used below would be unbound for this op.
        raise UnsupportedModelError(f"Unexpected pooling op at {vertex.name}, op={vertex.op}")

    if vertex.op == "MaxPool" and vertex.is_1d_maxpool():
        # 1D max pooling: swap indices 1 and 2 of the output shapes, kernel and strides.
        for shape in output_shapes:
            shape[1:3] = [shape[2], shape[1]]
        kernel_shape[1:3] = [kernel_shape[2], kernel_shape[1]]
        strides[1:3] = [strides[2], strides[1]]

    return PoolingLayer.create(
        vertex.name,
        vertex.get_inputs(),
        op,
        kernel_shape,
        strides,
        padding,
        should_set_kernel_to_input_shape=should_set_kernel_to_input_shape,
        output_shapes=output_shapes,
        count_include_pad=False,
    )


def _create_dense_layer(vertex):
    """Create a ``DenseLayer`` from the kernel variable of ``vertex``.

    When the vertex's graph is NCHW, the kernel is transposed before being passed to the layer and
    the NCHW flag is forwarded to ``DenseLayer.create``.
    """
    weights, _weights_shape = _fill_from_variable(vertex)
    nchw = vertex.graph.is_nchw
    dense_kernel = weights.transpose() if nchw else weights
    return DenseLayer.create(
        vertex.name,
        vertex.get_inputs(),
        None,
        dense_kernel,
        nchw,
        vertex.get_output_shapes(),
    )


def _create_batch_norm_layer(vertex):
    """Create a ``BatchNormLayer`` using the batch-norm info extracted from ``vertex``."""
    name = vertex.name
    inputs = vertex.get_inputs()
    bn_info = vertex.get_bn_info()
    return BatchNormLayer.create(name, inputs, bn_info, vertex.get_output_shapes())


def _create_activation_layer(vertex, is_relu6=False, is_leaky_max=False, is_keras_prelu=False):
    """Create an ``ActivationLayer`` from ``vertex``.

    The activation type is chosen from ``vertex.op`` and, for composite patterns (``Mul``, division
    ops, ``Relu`` with Keras PReLU, ...), from pattern-detection helpers on the vertex.

    Args:
        vertex: TF graph vertex to translate.
        is_relu6: Produce a ReLU6 activation regardless of ``vertex.op``.
        is_leaky_max: Unless an earlier op check matches, treat the vertex as a leaky pattern whose
            alpha is read via ``vertex.get_leaky_alpha(is_leaky_max=True)``; a negative alpha yields
            a PReLU with a single slope.
        is_keras_prelu: With a ``Relu`` op, read slopes via ``vertex.get_prelu_slope()``; uniform
            slopes yield a leaky activation, otherwise a PReLU.

    Returns:
        The ``ActivationLayer``.

    Raises:
        UnsupportedModelError: If no supported activation matches ``vertex``.
    """
    leaky_alpha = None
    activation_threshold = None
    delta_bias = None
    prelu_slope = None
    swish_beta = None
    activation_less_values = None
    hardsigmoid_alpha = None
    hardsigmoid_beta = None
    clip_min = None
    clip_max = None
    if vertex.op == "Relu" and not is_relu6 and not is_keras_prelu:
        activation = ActivationType.relu
    elif vertex.op == "Relu6" or is_relu6:
        activation = ActivationType.relu6
    elif vertex.op == "Elu":
        activation = ActivationType.elu
    elif vertex.op == "Sigmoid":
        activation = ActivationType.sigmoid
    elif vertex.op == "Exp":
        activation = ActivationType.exp
    elif vertex.op == "LeakyRelu" or is_leaky_max:
        vertex_leaky_alpha = vertex.get_leaky_alpha(is_leaky_max=is_leaky_max)
        if vertex_leaky_alpha < 0:
            activation = ActivationType.prelu
            prelu_slope = [vertex_leaky_alpha]
        else:
            activation = ActivationType.leaky
            leaky_alpha = vertex_leaky_alpha
    elif vertex.op == "Tanh":
        activation = ActivationType.tanh
    elif vertex.op == "Mul":
        if vertex.is_threshold_activation():  # Threshold activation
            activation = ActivationType.threshold
            activation_threshold = vertex.get_threshold()
        elif vertex.is_mish_activation():
            activation = ActivationType.mish
        elif vertex.is_hardswish_activation():
            activation = ActivationType.hardswish
        elif vertex.is_hardsigmoid():
            activation = ActivationType.hardsigmoid
            hardsigmoid_alpha, hardsigmoid_beta = vertex.get_hardsigmoid_info()
        elif vertex.is_swish_activation_second_mul():
            activation = ActivationType.swish
            swish_beta = vertex.get_swish_beta()
        else:
            raise UnsupportedModelError(f"Unexpected activation at {vertex.name}, op={vertex.op}")
    elif vertex.op in DIV_OPS:
        if vertex.is_gelu_activation():
            activation = ActivationType.gelu
        elif vertex.is_inv_pos_activation():
            activation = ActivationType.inv_pos
        else:
            # Without this branch `activation` would be unbound when building the layer.
            raise UnsupportedModelError(f"Unexpected activation at {vertex.name}, op={vertex.op}")
    elif vertex.op in ["Sign", "Abs"]:
        activation = ActivationType.biased_delta
        delta_bias = vertex.get_delta_bias_value()
    elif vertex.op == "Relu" and is_keras_prelu:
        vertex_prelu_slope = vertex.get_prelu_slope()
        # Compare element-wise through numpy: for a plain list, `list != scalar` is always True,
        # which would turn every uniform slope into a PReLU.
        slopes = np.asarray(vertex_prelu_slope)
        if np.any(slopes != slopes[0]):
            activation = ActivationType.prelu
            prelu_slope = vertex_prelu_slope
        else:
            activation = ActivationType.leaky
            leaky_alpha = vertex_prelu_slope[0]
    elif vertex.op == "Softplus":
        activation = ActivationType.softplus
    elif vertex.op == "IdentityN":
        activation = ActivationType.silu
    elif vertex.op == "Sqrt":
        activation = ActivationType.sqrt
    elif vertex.op == "Less":
        activation = ActivationType.less
        activation_less_values = vertex.get_activation_less_values()
    elif vertex.op == "Log":
        activation = ActivationType.log
    elif vertex.op in ("Minimum", "Maximum"):
        activation = ActivationType.clip
        clip_min, clip_max = vertex.get_min_max_info()
    elif vertex.op == "Softsign":
        activation = ActivationType.softsign
    else:
        raise UnsupportedModelError(f"Unexpected activation at {vertex.name}, op={vertex.op}")
    return ActivationLayer.create(
        vertex.name,
        vertex.get_inputs(),
        activation,
        leaky_alpha,
        activation_threshold,
        delta_bias,
        vertex.get_output_shapes(),
        prelu_slope,
        swish_beta,
        activation_less_values,
        hardsigmoid_alpha=hardsigmoid_alpha,
        hardsigmoid_beta=hardsigmoid_beta,
        clip_min=clip_min,
        clip_max=clip_max,
    )


def _create_ew_add_layer(vertex):
    """Create an element-wise add layer from a vertex with exactly two inputs.

    Raises:
        UnsupportedEWLayerError: If ``vertex`` does not have exactly two inputs.
    """
    add_inputs = vertex.get_inputs()
    if len(add_inputs) == 2:
        return EWAddLayer.create(vertex.name, add_inputs, vertex.get_output_shapes())
    raise UnsupportedEWLayerError(f"Number of inputs in add op {vertex.name} is not 2, and is not supported")


def _create_ew_add_n_layer(vertex):
    """Create an ``EWAddNLayer`` over all inputs of ``vertex``."""
    layer_name = vertex.name
    summands = vertex.get_inputs()
    return EWAddNLayer.create(layer_name, summands, vertex.get_output_shapes())


def _create_ew_sub_layer(vertex):
    """Create an element-wise subtract layer from a vertex with exactly two inputs.

    Raises:
        UnsupportedEWLayerError: If ``vertex`` does not have exactly two inputs.
    """
    operands = vertex.get_inputs()
    has_two_operands = len(operands) == 2
    if not has_two_operands:
        raise UnsupportedEWLayerError(f"Number of inputs in sub op {vertex.name} is not 2, and is not supported")
    shapes = vertex.get_output_shapes()
    return EWSubLayer.create(vertex.name, operands, output_shapes=shapes)


def _create_slice_layer(vertex):
    """Create a ``SliceLayer`` from the height/width/features slices reported by ``vertex``."""
    slices = vertex.get_slices_values()
    h_slice, w_slice, f_slice = slices
    return SliceLayer.create(vertex.name, vertex.get_inputs(), h_slice, w_slice, f_slice, vertex.get_output_shapes())


def _create_concat_layer(vertex):
    """Create a ``ConcatLayer`` joining the inputs of ``vertex`` along ``vertex.get_concat_axis()``."""
    name = vertex.name
    inputs = vertex.get_inputs()
    shapes = vertex.get_output_shapes()
    concat_axis = vertex.get_concat_axis()
    return ConcatLayer.create(name, inputs, output_shapes=shapes, axis=concat_axis)


def _create_resize_layer(vertex):
    """Create a ``ResizeLayer`` from ``vertex``.

    Handles three kinds of vertices:
        * ops in ``SPLIT_OPS``: nearest-neighbor resize described by ``vertex.get_1d_resize_info()``.
        * ``Reshape``: nearest-neighbor resize described by
          ``vertex.get_reshape_as_resize_nearest_info()``.
        * ``ResizeNearestNeighbor`` / ``ResizeBilinear``: native TF resize ops.

    Note: We don't call _validate_ratios when creating a layer from TF.
    The TF graph's resize node may have ratios greater than MAXIMUM_RESIZE_RATIO_PER_LAYER.
    This gets dealt with in the Fuser:
    * We split the resize ratios into a list, where:
      * Each ratio in the list is at most MAXIMUM_RESIZE_RATIO_PER_LAYER.
      * The product of all the ratios == the original ratio.
    * E.g. TF resize -> h_ratios=[32] -> Fuser -> h_ratios=[16, 2]
           (16, 2 <= MAXIMUM_RESIZE_RATIO_PER_LAYER; 16 * 2 == 32)

    Raises:
        UnsupportedResizeLayerError: If dims 1:3 of a native resize's output shapes disagree with the
            expected output size, or if the native op is neither nearest-neighbor nor bilinear.
    """
    if vertex.op in SPLIT_OPS:
        resize_method = ResizeMethod.nearest_neighbor
        upscale_factors, output_shapes = vertex.get_1d_resize_info()
    elif vertex.op == "Reshape":
        resize_method = ResizeMethod.nearest_neighbor
        upscale_factors, output_shapes = vertex.get_reshape_as_resize_nearest_info()
    else:
        output_shapes = vertex.get_output_shapes()
        upscale_factors = None
        resize_info = vertex.get_resize_info()
        expected_output_shape = resize_info.expected_output_shape
        if resize_info.forced_by_unknown_shape:
            output_shapes = expected_output_shape
        elif resize_info.is_upscale_factors:
            upscale_factors = list(expected_output_shape)
        elif expected_output_shape is not None:
            # `output_shapes` holds one shape per output, so compare dims 1:3 of each output shape.
            # Slicing the list itself (`output_shapes[1:3]`) would compare nothing for a single
            # output and whole shapes against scalars for several outputs.
            for shape in output_shapes:
                if not all(dim == expected_dim for dim, expected_dim in zip(shape[1:3], expected_output_shape)):
                    raise UnsupportedResizeLayerError(f"Unexpected resize layer output shapes at {vertex.name}")

        if vertex.op == "ResizeNearestNeighbor":
            resize_method = ResizeMethod.nearest_neighbor
        elif vertex.op == "ResizeBilinear":
            resize_method = ResizeMethod.bilinear
        else:
            raise UnsupportedResizeLayerError(
                f"Unsupported resize method at {vertex.name}. Received method={vertex.op}",
            )

    pixels_mode = vertex.get_resize_bilinear_pixels_mode()
    return ResizeLayer.create(
        vertex.name,
        vertex.get_inputs(),
        resize_method,
        output_shapes=output_shapes,
        upscale_factors=upscale_factors,
        resize_bilinear_pixels_mode=pixels_mode,
    )


def _create_feature_shuffle_layer(vertex):
    """Create a ``FeatureShuffleLayer`` from the entry reshape of a shuffle pattern.

    The group count is read from index 3 of ``vertex``'s first output shape. The output shapes and
    the recorded last-reshape name come from the pattern's final reshape node.

    Raises:
        UnsupportedModelError: If the final transpose->reshape of the pattern is not found.
    """
    groups = vertex.get_output_shapes()[0][3]  # shape asserted to be 5 dimensional in is_shuffle()
    last_reshape = vertex.get_shuffle_last_reshape()
    if last_reshape is not None:
        shuffled_shapes = last_reshape.get_output_shapes()
        return FeatureShuffleLayer.create(
            vertex.name,
            vertex.get_inputs(),
            groups=groups,
            last_reshape_name=last_reshape.name,
            output_shapes=shuffled_shapes,
        )
    raise UnsupportedModelError(
        f"Unexpected entrance shuffle node without transpose->reshape for node: {vertex.name}",
    )


def _create_depth_to_space_layer(vertex, is_asymmetric=False):
    """Create a ``DepthToSpaceLayer`` from ``vertex``.

    Args:
        vertex: TF graph vertex to translate.
        is_asymmetric: Read block size and output shapes from
            ``vertex.get_asymmetric_depth_to_space_params()`` instead of the regular getters.
    """
    if not is_asymmetric:
        block_size = vertex.get_depth_to_space_block_size()
        shapes = vertex.get_output_shapes()
    else:
        block_size, shapes = vertex.get_asymmetric_depth_to_space_params()
    return DepthToSpaceLayer.create(vertex.name, vertex.get_inputs(), block_size, shapes)


def _create_argmax_layer(vertex):
    """Create an ``ArgmaxLayer`` whose output shapes are ``vertex``'s output shapes plus a trailing 1.

    New shape lists are built rather than appending in place, so the lists returned by
    ``vertex.get_output_shapes()`` are left untouched (in case the vertex caches them); this also
    accepts tuple shapes.
    """
    output_shapes = [[*shape, 1] for shape in vertex.get_output_shapes()]
    return ArgmaxLayer.create(vertex.name, vertex.get_inputs(), output_shapes)


def _create_softmax_layer(vertex):
    """Create a ``SoftmaxLayer`` over the inputs of ``vertex``."""
    layer_name = vertex.name
    softmax_inputs = vertex.get_inputs()
    return SoftmaxLayer.create(layer_name, softmax_inputs, output_shapes=vertex.get_output_shapes())


def _create_feature_splitter_layer(vertex):
    """Create a feature splitter layer from ``vertex``.

    For ops outside ``SPLIT_OPS`` only the basic layer params are set. A split with exactly one
    output is returned as a ``SliceLayer`` over the features range instead of a splitter.
    """
    splitter = FeatureSplitterLayer()
    _set_basic_layer_params(splitter, vertex)
    if vertex.op not in SPLIT_OPS:
        return splitter

    _split_dims, split_indices, split_shapes = vertex.get_feature_split_info()
    if len(split_shapes) != 1:
        set_op(LayerType.feature_splitter, splitter, vertex)
        return splitter

    # Support for feature splitter with one output, converted automatically to slice layer
    only_shape = split_shapes[0]
    start = split_indices[0]
    return SliceLayer.create(
        vertex.name,
        vertex.get_inputs(),
        [0, only_shape[1]],
        [0, only_shape[2]],
        [start, start + only_shape[3]],
        output_shapes=split_shapes,
    )


def _create_bias_add_layer(vertex):
    """Create a ``BiasAddLayer`` from ``vertex``'s bias variable.

    A 2-D bias is reshaped to a 1-D vector of length ``shape[1]``.
    """
    bias, _bias_shape = _fill_from_variable(vertex)
    is_2d = len(bias.shape) == 2
    flat_bias = np.reshape(bias, [bias.shape[1]]) if is_2d else bias
    return BiasAddLayer.create(vertex.name, vertex.get_inputs(), flat_bias, vertex.get_output_shapes())


def _create_format_conversion_layer(vertex, conversion_type=FormatConversionType.tf_rgb_to_hailo_rgb):
    """Create a ``FormatConversionLayer`` of ``conversion_type`` from ``vertex``.

    For ``flat_to_frames`` conversions, every rank-3 output shape gets a dim of 1 inserted at index 1;
    other shapes are kept as-is.
    """
    shapes = vertex.get_output_shapes()
    if conversion_type == FormatConversionType.flat_to_frames:
        expanded = []
        for shape in shapes:
            expanded.append(shape[:1] + [1] + shape[1:] if len(shape) == 3 else shape)
        shapes = expanded
    return FormatConversionLayer.create(
        vertex.name,
        vertex.get_inputs(),
        conversion_type=conversion_type,
        output_shapes=shapes,
    )


def _create_ew_mult_layer(vertex):
    """Create an element-wise multiply layer from a vertex with exactly two inputs.

    Raises:
        UnsupportedEWLayerError: If ``vertex`` does not have exactly two inputs.
    """
    factors = vertex.get_inputs()
    if len(factors) == 2:
        return EWMultLayer.create(vertex.name, factors, vertex.get_output_shapes())
    raise UnsupportedEWLayerError(f"Number of inputs in mul op {vertex.name} is not 2, and is not supported")


def _create_ew_div_layer(vertex):
    """Create an element-wise divide layer from a vertex with exactly two inputs.

    Raises:
        UnsupportedEWLayerError: If ``vertex`` does not have exactly two inputs.
    """
    operands = vertex.get_inputs()
    if len(operands) == 2:
        return EWDivLayer.create(vertex.name, operands, vertex.get_output_shapes())
    raise UnsupportedEWLayerError(f"Number of inputs in div op {vertex.name} is not 2, and is not supported")


def _create_normalization_layer(vertex):
    """Create a ``NormalizationLayer`` from ``vertex``.

    A multiply-by-2 add pattern is encoded as mean ``[0.0]`` and std ``[0.5]`` (presumably the layer
    computes ``(x - mean) / std`` -- confirm). When every std is 1 or every mean is 0, the layer is
    typed ``mul_and_add`` instead of a full normalization.
    """
    if not vertex.is_mul_by_2_ew_add():
        mean, std, output_shapes = vertex.get_normalization_info()
    else:
        mean, std = [0.0], [0.5]
        output_shapes = vertex.get_output_shapes()
    # if std = 1 or mean = 0, normalization is actually standalone mul, sub, or add
    standalone_op = all(value == 1 for value in std) or all(value == 0 for value in mean)
    if standalone_op:
        norm_type = NormalizationType.mul_and_add
    else:
        norm_type = NormalizationType.normalization
    return NormalizationLayer.create(
        vertex.name,
        vertex.get_inputs(),
        mean,
        std,
        output_shapes=output_shapes,
        normalization_type=norm_type,
    )


def _create_reduce_max_layer(vertex):
    """Create a ``ReduceMaxLayer`` from ``vertex``.

    Raises:
        UnsupportedReduceMaxLayerError: If ``vertex.is_valid_reduce_max_min()`` is false.
    """
    if vertex.is_valid_reduce_max_min():
        return ReduceMaxLayer.create(vertex.name, vertex.get_inputs(), output_shapes=vertex.get_output_shapes())
    raise UnsupportedReduceMaxLayerError(
        f"Unexpected properties in reduce max layer created from vertex {vertex.name}. "
        "Reduce max is only supported in the features axis and with keepdims=True.",
    )


def _create_space_to_depth_layer(vertex, space_to_depth_type=SpaceToDepthType.classic_dcr):
    """Create a ``SpaceToDepthLayer`` of ``space_to_depth_type`` from ``vertex``.

    For the ``focus`` type the block size is fixed to 2 and each output shape keeps all but its last
    dim while the last dim is multiplied by ``block_size * block_size``. Other types read the block
    size from the vertex and use its output shapes unchanged.

    Raises:
        UnsupportedSpaceToDepthLayerError: For a ``focus`` layer whose block size is not 2.
    """
    shapes = vertex.get_output_shapes()
    if space_to_depth_type != SpaceToDepthType.focus:
        block_size = vertex.get_space_to_depth_block_size()
    else:
        block_size = 2
        shapes = [shape[:-1] + [shape[-1] * block_size * block_size] for shape in shapes]
    # NOTE(review): this check cannot fire -- the focus branch above hard-codes block_size to 2.
    if block_size != 2 and space_to_depth_type == SpaceToDepthType.focus:
        raise UnsupportedSpaceToDepthLayerError(
            f"Space to depth layers focus type are only supported with block size "
            f"of 2, while in node {vertex.name} the block size is {block_size}",
        )
    return SpaceToDepthLayer.create(
        vertex.name,
        vertex.get_inputs(),
        [block_size, block_size],
        output_shapes=shapes,
        space_to_depth_type=space_to_depth_type,
    )


def _create_external_pad_layer(vertex):
    """Create an ``ExternalPadLayer`` from the constant padding values of ``vertex``.

    Raises:
        UnsupportedPaddingError: If the padding values are falsy, do not have exactly 6 entries, or
            contain a negative entry.
    """
    pads = vertex.get_padding_from_const()
    is_legal = bool(pads) and len(pads) == 6 and all(value >= 0 for value in pads)
    if not is_legal:
        raise UnsupportedPaddingError(f"Illegal padding values in node {vertex.name}")
    return ExternalPadLayer.create(
        vertex.name,
        vertex.get_inputs(),
        output_shapes=vertex.get_output_shapes(),
        padding_vals=pads,
    )


def _create_reduce_sum_layer(vertex):
    """Create a ``ReduceSumLayer`` from ``vertex``.

    Raises:
        UnsupportedReduceSumLayerError: If ``vertex.is_valid_reduce_sum()`` reports the vertex as
            invalid or the reduction axes include axis 0.
    """
    valid, axes = vertex.is_valid_reduce_sum()
    shapes = vertex.get_output_shapes()
    if valid and 0 not in axes:
        return ReduceSumLayer.create(vertex.name, vertex.get_inputs(), axes, output_shapes=shapes)
    raise UnsupportedReduceSumLayerError(
        f"Unexpected properties in reduce sum layer created from vertex {vertex.name}. "
        "Reduce sum is only supported in the height, width or features axis and with keepdims=True.",
    )


def _create_matmul_layer(vertex):
    """Create a ``MatmulLayer`` from ``vertex``.

    When ``vertex.should_transpose_einsum_matmul_input()`` is true, the input order is reversed and
    the layer is created with ``should_transpose_input=True``.
    """
    # Query the predicate once and reuse it for both the input order and the layer flag.
    should_transpose = vertex.should_transpose_einsum_matmul_input()
    inputs = vertex.get_inputs()
    if should_transpose:
        inputs = inputs[::-1]
    return MatmulLayer.create(
        vertex.name,
        inputs,
        output_shapes=vertex.get_output_shapes(),
        should_transpose_input=should_transpose,
    )


def _create_square_layer(vertex):
    """Create a square ``FeatureMultiplierLayer`` from ``vertex``.

    ``Pow`` ops are accepted only with an exponent of exactly 2; other ops are not checked here.

    Raises:
        UnsupportedSquareLayerError: For a ``Pow`` op whose power is not 2.
    """
    if vertex.op == "Pow":
        # Named `exponent` rather than `pow` to avoid shadowing the builtin.
        exponent = vertex.get_power()
        if exponent != 2.0:
            raise UnsupportedSquareLayerError(
                f"Pow operator {vertex.name} can only be supported as square (got power of {exponent}).",
            )
    return FeatureMultiplierLayer.create(
        vertex.name,
        vertex.get_inputs(),
        feature_multiplier_type=FeatureMultiplierType.square,
        output_shapes=vertex.get_output_shapes(),
    )


def _create_reduce_l2_layer(vertex):
    """Create a ``ReduceL2Layer`` from ``vertex``.

    Raises:
        UnsupportedReduceL2LayerError: If ``vertex.is_valid_reduce_l2()`` reports the vertex as invalid
            or the reduction axes include axis 0.
    """
    valid, axes = vertex.is_valid_reduce_l2()
    if not valid or 0 in axes:
        # The message previously said "Reduce sum" (copied from the reduce-sum factory).
        raise UnsupportedReduceL2LayerError(
            f"Unexpected properties in reduce l2 layer created from vertex "
            f"{vertex.name}. Reduce l2 is only supported in the height, width or "
            "features axis and with keepdims=True.",
        )
    return ReduceL2Layer.create(vertex.name, vertex.get_inputs(), axes, output_shapes=vertex.get_output_shapes())


def _create_l2_normalization_layer(vertex):
    """Create an ``L2NormalizationLayer`` from ``vertex``.

    The scale is ``sqrt(1 / input_shape[axes[0]])`` when ``vertex.get_sum_axes()`` returns exactly one
    axis, otherwise ``sqrt(1)``.

    Raises:
        UnsupportedReduceL2LayerError: If the first input shape has fewer than 4 dims.
    """
    input_shape = vertex.get_input_shapes()[0]
    if len(input_shape) < 4:
        raise UnsupportedReduceL2LayerError(
            f"Unexpected properties in l2 normalization layer created from vertex "
            f"{vertex.name}. l2 normalization is only supported with rank 4 input shape",
        )
    axes = vertex.get_sum_axes()
    scale = np.sqrt(1 / input_shape[axes[0]] if len(axes) == 1 else 1)
    return L2NormalizationLayer.create(
        vertex.name,
        # NOTE(review): other factories pass `vertex.get_inputs()`; confirm `vertex.input` is intended.
        vertex.input,
        axes,
        scale,
        output_shapes=vertex.get_output_shapes(),
    )


def _set_basic_layer_params(layer, vertex):
    try:
        layer.output_shapes = vertex.get_output_shapes()
        layer.add_original_name(vertex.name)
        layer.input_vertex_order = vertex.get_inputs()
    except UnsupportedModelError as e:
        raise UnsupportedModelError(
            f"Error occurred in layer {layer.name} (translated from {vertex.name}): {e.client_message}",
        )


def set_op(layer_type, layer, vertex):
    """Populate the op-specific fields of ``layer`` from ``vertex`` for conv, dense or splitter layers.

    Raises:
        UnsupportedModelError: For any other ``layer_type``.
    """
    if layer_type == LayerType.conv:
        setter = _set_conv_op
    elif layer_type == LayerType.dense:
        setter = _set_dense_op
    elif layer_type == LayerType.feature_splitter:
        setter = _set_group_conv_splitter_op
    else:
        raise UnsupportedModelError(f"Set op with layer type {layer_type} is not supported")
    setter(layer, vertex)


def _set_conv_op(layer, vertex):
    """Populate ``layer``'s conv fields (kernel, strides, dilations, padding, op) from ``vertex``.

    * Einsum ops are set up as a kernel from ``vertex.get_einsum_1x1_info()`` with unit
      strides/dilations and VALID padding, and return early.
    * A ``DepthwiseConv2dNative`` with two non-variable predecessors and two input shapes gets
      ``dynamic_weights=True`` and a kernel shape taken from its second input; otherwise the kernel
      is read from the vertex's variable.
    * A 3-element kernel shape is treated as Conv1D and prefixed with a 1.
    * A depthwise op whose kernel's last dim is > 1 is rewritten as a grouped ``base_conv``.

    Raises:
        UnsupportedModelError: If reading the vertex's output shapes fails; the message is prefixed
            with the layer/vertex names and the original error is chained.
    """
    layer.add_original_name(vertex.name)

    if vertex.op in EINSUM_OPS:
        layer.kernel, layer.kernel_shape = vertex.get_einsum_1x1_info()
        layer.strides, layer.dilations = [1, 1, 1, 1], [1, 1, 1, 1]
        layer.padding = PaddingType.valid
        layer.output_shapes = vertex.get_output_shapes()
        return

    # Loop-invariant: build the excluded-op collection once instead of per predecessor.
    excluded_ops = VAR_OPS + OTHER_OPS
    non_var_preds = [x for x in vertex.graph.predecessors(vertex) if x.op not in excluded_ops]
    if len(non_var_preds) == 2 and len(vertex.get_input_shapes()) == 2 and vertex.op == "DepthwiseConv2dNative":
        # Presumably the depthwise kernel is a runtime tensor here (second input) -- confirm.
        layer.kernel_shape = vertex.get_input_shapes()[1][:3] + [1]
        layer.dynamic_weights = True
    else:
        layer.kernel, layer.kernel_shape = _fill_from_variable(vertex)

    if len(layer.kernel_shape) == 3:  # Conv1D
        layer.kernel_shape = [1, *layer.kernel_shape]
        layer.kernel = layer.kernel.reshape(layer.kernel_shape)
    if layer.pre_layer_bias is not None:
        layer.compute_pre_layer_bias_approx()
    layer.strides = vertex.get_strides()

    # this case handles regular conv layers with generic dilation (not s2b/b2s ops)
    if not layer.is_dilated_s2b:
        layer.dilations = vertex.get_dilations()
        layer.padding = _get_padding_from_op(layer.padding, vertex)

    # in this case from_vertex was called with a preceding op (such as pad/space2batch)
    if vertex.op == "Conv2D":
        layer.op = LayerType.base_conv
    elif vertex.op == "DepthwiseConv2dNative":
        layer.op = LayerType.base_dw

    # Depthwise with a kernel last dim > 1 is expressed as a grouped conv.
    if layer.op in [LayerType.base_dw] and layer.kernel_shape[-1] > 1:
        layer.op = LayerType.base_conv
        groups = layer.kernel_shape[2]
        layer.groups = groups
        new_kernel_shape = [*layer.kernel_shape[:2], 1, groups * layer.kernel_shape[-1]]
        layer.kernel = layer.kernel.reshape(new_kernel_shape)
        new_kernel_shape[2] *= groups
        layer.kernel_shape = new_kernel_shape

    try:
        layer.output_shapes = vertex.get_output_shapes()
    except UnsupportedModelError as e:
        raise UnsupportedModelError(
            f"Error occurred in layer {layer.name} (translated from {vertex.name}): {e.client_message}",
        ) from e


def _set_dense_op(layer, vertex):
    """Fill ``layer``'s kernel from ``vertex`` and refresh the pre-layer-bias approximation if one is set."""
    layer.add_original_name(vertex.name)
    kernel, kernel_shape = _fill_from_variable(vertex)
    layer.kernel = kernel
    layer.kernel_shape = kernel_shape
    has_pre_bias = layer.pre_layer_bias is not None
    if has_pre_bias:
        layer.compute_pre_layer_bias_approx()


def _set_group_conv_splitter_op(layer, vertex):
    layer.add_original_name(vertex.name)
    layer.split_sizes, _, layer.output_shapes = vertex.get_feature_split_info()


def _fill_from_variable(vertex):
    return vertex.get_layer_var_data()


def _set_pre_layer_bias(layer, vertex):
    """Attach ``vertex``'s variable value as ``layer.pre_layer_bias``, logging an approximation warning."""
    warning_text = (
        f"Vertex {vertex.name} is a pre-layer bias. This feature is supported in SDK using "
        "approximation only. It's recommended to run evaluation on the full-precision model "
        "before optimization to verify accuracy is still good enough."
    )
    default_logger().warning(warning_text)
    layer.add_original_name(vertex.name)
    bias_value, _bias_shape = _fill_from_variable(vertex)
    layer.pre_layer_bias = bias_value


def set_dilation_op(layer, vertex, is_dilation_s2b):
    """Record whether ``layer`` is an s2b-dilated conv and read its dilations from ``vertex`` accordingly."""
    layer.is_dilated_s2b = is_dilation_s2b
    dilations = vertex.get_dilations(is_dilation_s2b)
    layer.dilations = dilations


def set_external_output_shape(layer, vertex):
    """Override ``layer``'s output shapes with ``vertex``'s, for s2b-dilated conv layers only.

    Raises:
        UnsupportedModelError: If ``layer`` is not marked as an s2b-dilated conv.
    """
    if not layer.is_dilated_s2b:
        raise UnsupportedModelError(
            f"External output shape in {layer.op} layer (translated from {vertex.name}) can only be set with "
            "dilated conv2d s2b/b2s nodes",
        )
    # This case overrides previous update from set_op, since it was wrong to use the shape from conv2d when
    # the origin was a s2b->conv2d->b2s constellation for dilated conv op
    layer.add_original_name(vertex.name)
    layer.output_shapes = vertex.get_output_shapes()


def _get_padding_from_op(layer_padding, vertex):
    op_padding = vertex.get_padding_from_op()

    if layer_padding is None:
        return op_padding

    if op_padding != PaddingType.valid:
        raise UnsupportedPaddingError(
            f"Expected VALID padding in input TF model after explicit TF padding node (node={vertex.name})",
        )
    if layer_padding != PaddingType.same:
        raise UnsupportedPaddingError(
            f"Expected SAME in HN padding after explicit TF padding node (node={vertex.name})",
        )
    # This is an external padding done manually. Currently we support one custom padding scheme.
    return layer_padding
