from collections import namedtuple

import numpy as np

from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, PaddingType, TemporaryPaddingType


class CalculateShapeException(Exception):
    """Raised when an output shape cannot be computed, e.g. for an unsupported padding type."""


# The initial index assigned to Layer objects:
# * If a Layer's index == INDEX_NOT_SET, its order relative to other nodes in
#   a graph will be determined by next_insertion_order.
# * We can't use None for this value, because ordering comparisons involving None
#   (e.g. None < None or None < 0) raise TypeError in Python 3; a sentinel int
#   keeps indices mutually comparable.
INDEX_NOT_SET = -1

# NOTE(review): presumably the name prefix for layers created from a model script — confirm against usages.
MODEL_SCRIPT_LAYER_PREFIX = "model_script_layer"


def get_act_short_description(layer):
    """Return a short suffix naming ``layer``'s activation.

    Yields ``""`` when the activation is ``ActivationType.linear``; otherwise
    ``" +"`` followed by the capitalized activation value (a value of
    ``"relu"`` would give ``" +Relu"``).
    """
    activation = layer.activation
    return f" +{activation.value.capitalize()}" if activation != ActivationType.linear else ""


def get_groups_short_description(layer):
    """Return ``" (groups=N)"`` when ``layer.groups`` exceeds 1, otherwise ``""``."""
    group_count = layer.groups
    return f" (groups={group_count})" if group_count > 1 else ""


def input_to_output_height_width(input_shape, kernel_shape, strides, padding, dilations=None):
    """Compute the spatial (height, width) output size of a windowed operation.

    Args:
        input_shape: Input shape; elements ``[1:3]`` are read as (height, width)
            (presumably an NHWC layout — TODO confirm against callers).
        kernel_shape: Kernel shape. When it has exactly 4 elements only the first
            two are used; otherwise it is used as given and its first two elements
            are read as the kernel (height, width).
        strides: Strides; elements ``[1:3]`` are read as the (height, width) strides.
        padding: A ``PaddingType`` / ``TemporaryPaddingType`` member.
        dilations: Optional dilations; elements ``[1:3]`` are used. Defaults to 1
            in both spatial dimensions.

    Returns:
        A tuple ``(output_height, output_width)``.

    Raises:
        CalculateShapeException: If ``padding`` is not a supported padding type.
    """
    output_h, output_w = input_shape[1:3]
    dilations = [1, 1] if dilations is None else dilations[1:3]
    kernel_shape = kernel_shape[:2] if len(kernel_shape) == 4 else kernel_shape

    if padding == PaddingType.deconv:
        # We don't want stride to reduce dimensions in deconv, only to act as a rate.
        # A non-1x1 kernel grows each spatial dim by one; a 1x1 kernel keeps the input size.
        if kernel_shape[0] != 1 or kernel_shape[1] != 1:
            output_h = int(np.ceil(output_h + 1))
            output_w = int(np.ceil(output_w + 1))
        return output_h, output_w

    if padding in (PaddingType.valid, TemporaryPaddingType.external_undecided, TemporaryPaddingType.conv3d):
        pads = [0, 0]
        round_mode = np.floor
    elif padding in (PaddingType.same, PaddingType.same_tensorflow, TemporaryPaddingType.same_lower):
        # For "same"-style padding the total pad is dilation * (kernel - 1) + 1 - stride,
        # so the ceil-mode formula below reduces to ceil(dim / stride).
        pads = [
            (dilation * (kernel - 1) + 1 - stride) / 2
            for kernel, stride, dilation in zip(kernel_shape, strides[1:3], dilations)
        ]
        round_mode = np.ceil
    else:
        # ``padding`` may not be an enum member (e.g. None or a raw string); fall back to
        # the object itself so callers get CalculateShapeException, not AttributeError.
        padding_name = getattr(padding, "value", padding)
        raise CalculateShapeException(f"Unsupported padding type {padding_name}")

    # Standard windowed-op output size: (dim + total_pad - effective_kernel) / stride + 1,
    # where effective_kernel = dilation * (kernel - 1) + 1, rounded per padding mode.
    return tuple(
        int(round_mode(((dim + int(2 * pad) - dilation * (kernel - 1) - 1) / stride) + 1))
        for dim, kernel, pad, stride, dilation in zip(input_shape[1:3], kernel_shape, pads, strides[1:3], dilations)
    )


# Immutable record of batch-normalization parameters, accessible by field name or by
# position (order: moving_mean, moving_variance, beta, gamma, epsilon).
# NOTE(review): element types (scalars vs per-channel arrays) are not fixed here — confirm with callers.
BatchNormValues = namedtuple(
    "BatchNormValues",
    (
        "moving_mean",
        "moving_variance",
        "beta",
        "gamma",
        "epsilon",
    ),
)
