import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops._misc_internals import get_tf_same_padding
from hailo_model_optimization.acceleras.utils.acceleras_definitions import PaddingType, StrideAlignType
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasImplementationError


def handle_padding(input, padding, kernel_size, strides, padding_const_value, stride_align, dilation_rate=(1, 1)):
    """
    Pad ``input`` according to the requested padding mode.

    Args:
        input: tensor to pad; ``input.shape[1:3]`` is handed to the
            padding-size helpers as the two spatial extents.
        padding: one of ``PaddingType.VALID``, ``PaddingType.SAME`` or
            ``PaddingType.DECONV``.
        kernel_size: kernel extents, forwarded to the padding-size helpers.
        strides: stride values, forwarded to the padding-size helpers.
        padding_const_value: fill value handed to ``diy_pad``.
        stride_align: alignment selector handed to ``diy_pad``.
        dilation_rate: forwarded to the padding-size helpers.

    Returns:
        ``input`` itself for ``VALID``; otherwise the result of ``diy_pad``.

    Raises:
        AccelerasImplementationError: for any other padding type.
    """
    # VALID means no padding at all - hand the tensor back untouched.
    if padding == PaddingType.VALID:
        return input

    # Pick the (begin_h, end_h, begin_w, end_w) amounts for the requested mode.
    if padding == PaddingType.SAME:  # DIY padding
        pad_amounts = get_tf_same_padding(dilation_rate, *input.shape[1:3], *kernel_size, *strides)
    elif padding == PaddingType.DECONV:  # Deconv padding
        pad_amounts = get_deconv_padding(dilation_rate, input.shape, kernel_size, strides)
    else:
        raise AccelerasImplementationError(f"Padding type {padding.value} is not supported")

    # Both padded modes share the same explicit padding step.
    beg_h, end_h, beg_w, end_w = pad_amounts
    return diy_pad(input, padding_const_value, stride_align, beg_h, end_h, beg_w, end_w)


def get_deconv_padding(dilation_rate, shape, kernel_size, strides):
    """
    Compute the padding amounts used for the ``DECONV`` padding mode.

    Starts from the amounts returned by ``get_tf_same_padding`` (called with
    ``shape[1:3]`` as the spatial extents). Along each spatial axis, the
    leading amount is increased by one whenever it differs from the trailing
    amount.

    Returns:
        Tuple ``(pad_beg_h, pad_end_h, pad_beg_w, pad_end_w)``.
    """
    top, bottom, left, right = get_tf_same_padding(dilation_rate, *shape[1:3], *kernel_size, *strides)
    # Bump the leading pad by one on any axis where the two sides are unequal.
    top = top + 1 if top != bottom else top
    left = left + 1 if left != right else left
    return top, bottom, left, right


def diy_pad(inp, input_zero_point, stride_align, pad_beg_h, pad_end_h, pad_beg_w, pad_end_w):
    """
    Explicitly pad the input of an internal op, filling with the zero-point
    rather than with zeros.

    Shared utility for operations with a spatial kernel (conv, pool, etc.).
    Only axes 1 and 2 of the four-axis padding spec are padded; axes 0 and 3
    get no padding.

    Args:
        inp: tensor to pad.
        input_zero_point: fill value; cast to ``inp.dtype`` before padding.
        stride_align: ``StrideAlignType.SE`` keeps the begin/end amounts as
            given; ``StrideAlignType.NW`` swaps begin and end on both axes.
        pad_beg_h, pad_end_h: begin/end amounts for axis 1.
        pad_beg_w, pad_end_w: begin/end amounts for axis 2.

    Returns:
        The padded tensor.

    Raises:
        AccelerasImplementationError: for any other ``stride_align`` value.

    TODO consider moving into a "Mixin" class to be mixed in with multiple inheritance by relevant ops only..
    """
    rows = [pad_beg_h, pad_end_h]
    cols = [pad_beg_w, pad_end_w]

    if stride_align == StrideAlignType.NW:
        # NW alignment mirrors the padding: end <-> begin on both axes.
        rows.reverse()
        cols.reverse()
    elif stride_align != StrideAlignType.SE:
        # SE needs no adjustment (the given amounts are used as-is); anything else is unsupported.
        raise AccelerasImplementationError(
            f"Unsupported spatial sampling grid alignment {stride_align} for strided conv/pool",
        )

    fill_value = tf.cast(input_zero_point, inp.dtype)
    return tf.pad(inp, [[0, 0], rows, cols, [0, 0]], constant_values=fill_value)
