import tensorflow as tf

""" Module contents:

    Provides small helpers that work around miscellaneous Keras/TensorFlow quirks.
"""


def get_tf_same_padding(dilations, input_h, input_w, kernel_h, kernel_w, strides_h, strides_w):
    """
    Reproduce the explicit padding that tensorflow conv2d applies internally for "SAME".

    When the total padding along an axis is odd, the extra pixel goes at the end
    (bottom/right). This means that (!!) in the ambiguous case of a strided conv on
    even-sized input, the sampling grid is aligned to SE (south-east).

    Args:
        dilations: (dil_h, dil_w) pair, or a falsy value (e.g. None) meaning no dilation.
        input_h, input_w: spatial size of the input.
        kernel_h, kernel_w: spatial size of the undilated kernel.
        strides_h, strides_w: conv stride along each axis.

    Returns:
        Tuple (pad_beg_h, pad_end_h, pad_beg_w, pad_end_w).
    """
    def _pad_one_axis(size, kernel, stride, dilation):
        # Extent of the kernel once dilation holes are inserted between taps.
        effective_kernel = kernel + (kernel - 1) * (dilation - 1)
        # Minimum total padding for which the output size is ceil(size / stride).
        remainder = size % stride
        covered = stride if remainder == 0 else remainder
        total = max(effective_kernel - covered, 0)
        # Floor half goes first; any odd leftover goes to the end, like tensorflow's SAME.
        before = total // 2
        return before, total - before

    dil_h, dil_w = dilations or [1, 1]
    top, bottom = _pad_one_axis(input_h, kernel_h, strides_h, dil_h)
    left, right = _pad_one_axis(input_w, kernel_w, strides_w, dil_w)
    return top, bottom, left, right


def hailo_reciprocal(x, epsilon=1e-10, tf_type=tf.float32):
    """
    Wrap tf.math.reciprocal, pushing x away from zero by epsilon before inverting element-wise.

    I.e., if x >= 0 then \\(y = 1 / (x + \\epsilon)\\) else \\(y = 1 / (x - \\epsilon)\\).
    Exact zeros count as non-negative, so a 0 entry maps to 1 / epsilon.

    Args:
        x: A `Tensor` (or tensor-convertible value) whose dtype is `tf_type`. Intended for
            floating dtypes (`bfloat16`, `half`, `float32`, `float64`). Complex dtypes are not
            supported by the sign comparison. With an integer `tf_type`, casting `epsilon`
            truncates it to 0, so no offset is applied.
        epsilon: A Python `float` magnitude added with the sign of `x`. Default is 1e-10.
        tf_type: The `tf.DType` used for `epsilon` and for the zero constant in the sign test.
            It should match the dtype of `x`, because TF binary ops do not promote
            mismatched dtypes. Default is `tf.float32`.

    Returns:
        A `Tensor` with the shape of `x`, computed in `tf_type`.
    """
    epsilon = tf.cast(epsilon, tf_type)
    # True where x >= 0, including exact zeros. NaN compares False; its result stays NaN anyway.
    non_negative = tf.math.less_equal(tf.cast(0.0, tf_type), x)
    # Turn the mask into a +1 / -1 sign tensor shaped like x (plain arithmetic, no tf.where,
    # so it does not depend on tf.where's version-specific broadcasting rules).
    sign = tf.cast(non_negative, tf_type) - tf.cast(tf.math.logical_not(non_negative), tf_type)
    return tf.math.reciprocal(x + sign * epsilon)
