#!/usr/bin/env python

# Substrings that valid_orig_name() replaces with DEFAULT_REPLACEMENT_CHAR.
UNSUPPORTED_CHARS = [";"]
DEFAULT_REPLACEMENT_CHAR = "_"
# Trailing marker that valid_orig_name() strips from the end of a name.
FUSED_ACTIVATION_SUFFIX = "_fused_activation"


def valid_orig_name(orig_name):
    """Return a sanitized copy of *orig_name*.

    Every entry of ``UNSUPPORTED_CHARS`` found in the name is replaced with
    ``DEFAULT_REPLACEMENT_CHAR``. If the result ends with
    ``FUSED_ACTIVATION_SUFFIX``, that single trailing suffix is removed.

    Args:
        orig_name: The name to sanitize (a string).

    Returns:
        The sanitized name.
    """
    valid_name = orig_name
    for unsupported in UNSUPPORTED_CHARS:
        valid_name = valid_name.replace(unsupported, DEFAULT_REPLACEMENT_CHAR)
    if valid_name.endswith(FUSED_ACTIVATION_SUFFIX):
        # Slice off only the trailing suffix: str.replace() would also delete
        # any earlier occurrence of the marker inside the name.
        valid_name = valid_name[: -len(FUSED_ACTIVATION_SUFFIX)]
    return valid_name


def map_hn_orig_names_to_orig_names(hn_model, orig_model_names):
    """Map each layer's original names to the matching names in *orig_model_names*.

    A model name matches an original name when ``valid_orig_name(model_name)``
    equals the original name. When several model names sanitize to the same
    string, the first one in *orig_model_names* wins.

    Args:
        hn_model: Iterable of layers; each layer exposes ``original_names``
            (an iterable of names, or a falsy value when there are none).
        orig_model_names: Iterable of names to look the original names up in.

    Returns:
        Dict from original name to the matching entry of *orig_model_names*.
        Original names without a match are omitted.
    """
    # Sanitize every model name once, instead of once per (layer, original
    # name) pair. setdefault keeps the first model name seen for a given
    # sanitized form, which preserves first-match semantics.
    sanitized_to_model_name = {}
    for model_name in orig_model_names:
        sanitized_to_model_name.setdefault(valid_orig_name(model_name), model_name)

    hn_orig_name_to_vertex_name = {}
    for layer in hn_model:
        # Layers with empty/None original_names contribute nothing.
        for orig_name in layer.original_names or ():
            if orig_name in sanitized_to_model_name:
                hn_orig_name_to_vertex_name[orig_name] = sanitized_to_model_name[orig_name]

    return hn_orig_name_to_vertex_name


def is_feature_repeats(layer_shape, neighbor_shape):
    """Tell whether one shape's last dimension is a proper multiple of the other's.

    Only the last entry of each shape is inspected.

    Returns:
        False when the two last dimensions are equal; otherwise True exactly
        when the larger one is divisible by the smaller one.
    """
    layer_features = layer_shape[-1]
    neighbor_features = neighbor_shape[-1]
    if layer_features == neighbor_features:
        return False

    if neighbor_features > layer_features:
        larger, smaller = neighbor_features, layer_features
    else:
        larger, smaller = layer_features, neighbor_features
    return larger % smaller == 0


def is_spatial_broadcast(first_input_shape, second_input_shape, is_two_sided=False):
    """Check whether the inputs of an element-wise layer need a spatial broadcast.

    Shapes longer than two entries are reduced to their spatial part by
    dropping the first and last entries (presumably batch and channels --
    confirm against callers). Index 0 of the result is treated as the height
    and index 1 as the width, so such shapes need at least four entries.

    A broadcast is needed along an axis when one input has size 1 there and
    the other input has a larger size:
        [1, w] and [h, w] --> [h, w] and [h, w]
        [h, 1] and [h, w] --> [h, w] and [h, w]
    By default only the first input may be the size-1 side. With
    ``is_two_sided=True`` the mirrored cases ([h, w] and [1, w], [h, w] and
    [h, 1]) are accepted as well.

    Returns:
        True when the height or the width axis needs broadcasting, else False.
        Always False when either shape has exactly two entries.
    """
    if any(len(shape) == 2 for shape in (first_input_shape, second_input_shape)):
        return False

    # Keep only the spatial part of shapes that carry extra leading/trailing axes.
    if len(first_input_shape) > 2:
        first_input_shape = first_input_shape[1:-1]
    if len(second_input_shape) > 2:
        second_input_shape = second_input_shape[1:-1]

    first_height, second_height = first_input_shape[0], second_input_shape[0]
    tallest = max(first_height, second_height)
    first_width, second_width = first_input_shape[1], second_input_shape[1]
    widest = max(first_width, second_width)

    def needs_broadcast(first_dim, second_dim, largest_dim):
        # Equal sizes (including 1 vs 1) never need broadcasting.
        if first_dim == second_dim:
            return False
        # The first input is the size-1 side.
        if second_dim == largest_dim and first_dim == 1:
            return True
        # The second input is the size-1 side: only allowed when two-sided.
        return is_two_sided and first_dim == largest_dim and second_dim == 1

    # Height is checked first; width is only consulted when height does not match.
    return bool(
        needs_broadcast(first_height, second_height, tallest)
        or needs_broadcast(first_width, second_width, widest)
    )
