from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import BackendNotImplementedError
from hailo_sdk_common.hailo_nn.hn_definitions import PaddingType


def calculate_padding_per_dim(padding_type, kernel, strides, input, dil):
    """Compute the (begin, end) padding for one spatial dimension.

    Args:
        padding_type: A ``PaddingType`` member. ``same``, ``same_tensorflow``
            and ``deconv`` get SAME-style padding; ``valid`` gets none.
        kernel: Kernel size along this dimension.
        strides: Stride along this dimension.
        input: Input size along this dimension.
        dil: Dilation factor along this dimension.

    Returns:
        A ``(pad_beg, pad_end)`` tuple.

    Raises:
        BackendNotImplementedError: If ``padding_type`` is none of the
            supported types.
    """
    # Footprint of the kernel once dilation gaps are inserted between taps.
    effective_kernel = kernel + (kernel - 1) * (dil - 1)

    same_family = (PaddingType.same, PaddingType.same_tensorflow, PaddingType.deconv)
    if padding_type not in same_family:
        if padding_type == PaddingType.valid:
            return 0, 0
        raise BackendNotImplementedError(f"Unsupported padding type {padding_type}")

    # SAME rule: subtract the stride when it divides the input, otherwise the remainder.
    remainder = input % strides
    step = strides if remainder == 0 else remainder
    total = max(effective_kernel - step, 0)

    # Deconv is padded by an even amount; surplus pixels are removed later.
    if padding_type == PaddingType.deconv and total % 2:
        total += 1

    # TensorFlow's SAME convention puts any odd extra pixel at the end.
    smaller = total // 2
    larger = total - smaller
    if padding_type == PaddingType.same:
        # HN's alternative convention puts the extra pixel at the beginning instead.
        return larger, smaller
    return smaller, larger


def calculate_padding(padding_type, kernel_h, kernel_w, strides_h, strides_w, input_h, input_w, dilations=None):
    """Compute begin/end padding for the height and width dimensions.

    Args:
        padding_type: A ``PaddingType`` member applied to both dimensions.
        kernel_h, kernel_w: Kernel size along height / width.
        strides_h, strides_w: Stride along height / width.
        input_h, input_w: Input size along height / width.
        dilations: Optional sequence whose elements at indices 1 and 2 are
            the height and width dilations. It is presumably a 4-element
            ``[1, dil_h, dil_w, 1]`` layout; TODO confirm with callers.
            ``None`` means no dilation.

    Returns:
        A ``(pad_beg_h, pad_end_h, pad_beg_w, pad_end_w)`` tuple.
    """
    if dilations is None:
        dil_h, dil_w = 1, 1
    else:
        dil_h, dil_w = dilations[1:3]

    top, bottom = calculate_padding_per_dim(padding_type, kernel_h, strides_h, input_h, dil_h)
    left, right = calculate_padding_per_dim(padding_type, kernel_w, strides_w, input_w, dil_w)
    return top, bottom, left, right
