#!/usr/bin/env python
from functools import reduce
from math import gcd

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import PrecisionMode
from hailo_sdk_common.hailo_nn.hn_definitions import NMSMetaArchitectures

# Module-level bit-width constant (8). It is not referenced within this chunk;
# presumably the default quantization bit-width — NOTE(review): confirm against its users.
QBITS = 8


class PostprocessCenterCalculationException(Exception):
    """Signals that bounding-box centers cannot be computed for the requested configuration,
    e.g. an unsupported NMS meta-architecture."""


def ceil_divide(dividend, divisor):
    """Return ``ceil(dividend / divisor)`` as a Python ``int``.

    When both operands are integers (Python ``int`` or NumPy integer scalars) the
    result is computed with exact integer arithmetic, so it stays correct for values
    beyond 2**53, where ``float`` loses precision. Any other operands fall back to
    floating-point division followed by ``ceil``.

    Raises:
        ZeroDivisionError: if both operands are integers and ``divisor`` is 0.
    """
    if isinstance(dividend, (int, np.integer)) and isinstance(divisor, (int, np.integer)):
        # -(-a // b) equals ceil(a / b) for every sign combination and never rounds.
        return -(-int(dividend) // int(divisor))
    return int(np.ceil(np.divide(float(dividend), divisor)))


def align_to(num, alignment):
    """Round ``num`` up to the nearest multiple of ``alignment`` (using ``ceil_divide``)."""
    chunk_count = ceil_divide(num, alignment)
    return chunk_count * alignment


def get_gcd(*numbers):
    """Return the greatest common divisor of the given integers.

    Accepts either separate arguments (``get_gcd(12, 18)``) or a single iterable
    (``get_gcd([12, 18])``), matching the calling convention of ``get_lcm``.

    Raises:
        TypeError: if no numbers are given (``reduce`` over an empty sequence).
    """
    values = numbers
    if len(numbers) == 1:
        # A lone argument may be a collection of numbers rather than a number itself.
        try:
            values = tuple(numbers[0])
        except TypeError:
            values = numbers
    return reduce(gcd, values)


def get_lcm(*numbers):
    """Return the least common multiple of the given integers.

    Accepts either a single iterable (``get_lcm([4, 6])``) or separate arguments
    (``get_lcm(4, 6)``). The result is non-negative. If any value is 0 the result is 0.
    An empty input yields 1.
    """
    values = numbers
    if len(numbers) == 1:
        # A lone argument may be a collection of numbers rather than a number itself.
        try:
            values = tuple(numbers[0])
        except TypeError:
            values = numbers

    def lcm(a, b):
        # lcm(x, 0) is 0 by convention; this also avoids gcd(0, 0) == 0 as a divisor.
        if a == 0 or b == 0:
            return 0
        return abs(a * b) // gcd(a, b)

    return reduce(lcm, values, 1)


def primes(n):
    """Return the prime factorization of ``n`` as an ascending list.

    Each prime appears once per multiplicity, e.g. ``primes(12) == [2, 2, 3]``.
    Integers smaller than 2 yield an empty list.
    """
    factors = []
    candidate = 2
    while candidate * candidate <= n:
        # Strip every power of this candidate before moving to the next one.
        while not n % candidate:
            factors.append(candidate)
            n //= candidate
        candidate += 1
    # Whatever survives trial division up to sqrt(n) is itself prime.
    if n > 1:
        factors.append(n)
    return factors


def get_bbox_centers_for_ssd(grid_height, grid_width, num_of_anchors):
    """Build normalized SSD-style cell centers for a grid.

    Along an axis of ``size`` cells, index ``i`` maps to ``i * (1 / size) + 0.5 * (1 / size)``,
    i.e. the middle of the cell in [0, 1] coordinates.

    Returns:
        ``(y_centers, x_centers)`` with shapes ``(grid_height, 2 * num_of_anchors)`` and
        ``(grid_width, 2 * num_of_anchors)``. Every column of a row holds that row's center.
    """
    columns = 2 * num_of_anchors
    step_y = 1.0 / float(grid_height)
    step_x = 1.0 / float(grid_width)
    # Half a cell: the offset from a cell's start to its middle, per spatial dimension.
    half_y = 0.5 * step_y
    half_x = 0.5 * step_x

    rows_y = np.array(range(grid_height)) * step_y + half_y
    rows_x = np.array(range(grid_width)) * step_x + half_x

    y_centers = np.repeat(rows_y[:, np.newaxis], columns, axis=1)
    x_centers = np.repeat(rows_x[:, np.newaxis], columns, axis=1)
    return y_centers, x_centers


def get_bbox_centers_for_centernet(grid_height, grid_width, num_of_anchors):
    """Build normalized CenterNet-style cell positions for a grid.

    Along an axis of ``size`` cells, index ``i`` maps to ``i / size``. No half-cell
    offset is applied.

    Returns:
        ``(y_centers, x_centers)`` with shapes ``(grid_height, 2 * num_of_anchors)`` and
        ``(grid_width, 2 * num_of_anchors)``. Every column of a row holds that row's value.
    """
    columns = 2 * num_of_anchors
    rows_y = np.array(range(grid_height))
    rows_x = np.array(range(grid_width))
    y_grid = np.repeat(rows_y[:, np.newaxis], columns, axis=1)
    x_grid = np.repeat(rows_x[:, np.newaxis], columns, axis=1)
    return y_grid / float(grid_height), x_grid / float(grid_width)


def get_bbox_centers_for_yolo(meta_arch, grid_height, grid_width, num_of_anchors):
    """Build normalized YOLO-style cell centers for a grid.

    Along an axis of ``size`` cells, index ``i`` maps to ``(i + shift) / size``, where
    ``shift`` is ``-0.5`` for YOLOV5 and ``+0.5`` for YOLOV6.

    Args:
        meta_arch: NMS meta-architecture. Only ``NMSMetaArchitectures.YOLOV5`` and
            ``NMSMetaArchitectures.YOLOV6`` are supported.
        grid_height: number of grid rows.
        grid_width: number of grid columns.
        num_of_anchors: each row of the result holds ``2 * num_of_anchors`` copies of
            that row's center.

    Returns:
        ``(y_centers, x_centers)`` with shapes ``(grid_height, 2 * num_of_anchors)`` and
        ``(grid_width, 2 * num_of_anchors)``.

    Raises:
        PostprocessCenterCalculationException: if ``meta_arch`` is neither YOLOV5 nor YOLOV6.
    """
    # Shift applied to the integer cell index, depending on the YOLO version.
    if meta_arch == NMSMetaArchitectures.YOLOV5:
        shift = -0.5
    elif meta_arch == NMSMetaArchitectures.YOLOV6:
        shift = 0.5
    else:
        # Fail loudly. Falling through would silently return unshifted centers.
        raise PostprocessCenterCalculationException(
            "Can't extract bbox centers for the specified architecture: {}".format(meta_arch)
        )

    columns = 2 * num_of_anchors
    y_centers = np.repeat(np.array(list(range(grid_height))), columns).reshape((grid_height, columns))
    x_centers = np.repeat(np.array(list(range(grid_width))), columns).reshape((grid_width, columns))
    y_centers = y_centers + shift
    x_centers = x_centers + shift

    return y_centers / float(grid_height), x_centers / float(grid_width)


def get_deconv_stack_order(kernel_shape, rate):
    """Return a hard-coded ordering of 2-tuples of ints for a kernel/rate combination.

    The table is selected by ``kernel_shape``. ``rate`` is consulted only when
    ``kernel_shape == 4``:

    * ``kernel_shape == 4 and rate == 4``: 16 pairs with coordinates in 0..3.
    * ``kernel_shape == 8`` (any ``rate``): 16 pairs with coordinates in 0..3.
    * ``kernel_shape == 16`` (any ``rate``): 64 pairs with coordinates in 0..7.
    * anything else: the 4 pairs ``(0, 0), (0, 1), (1, 0), (1, 1)``.

    Each call builds a new list, so mutating the result does not affect later calls.

    NOTE(review): the pairs presumably index sub-kernel phases of a strided
    deconvolution, and the list order is the order in which they are stacked. The
    tables must match whatever consumes them, so confirm against the caller before
    editing any entry.
    NOTE(review): ``kernel_shape == 4`` with a rate other than 4 (e.g. 2) falls through
    to the 2x2 default ordering. Confirm this is intended.
    """
    # 4x4 phase ordering; only used when the rate is also 4.
    if kernel_shape == 4 and rate == 4:
        return [
            (0, 0),
            (0, 2),
            (2, 0),
            (2, 2),
            (0, 1),
            (0, 3),
            (2, 1),
            (2, 3),
            (1, 0),
            (1, 2),
            (3, 0),
            (3, 2),
            (1, 1),
            (1, 3),
            (3, 1),
            (3, 3),
        ]
    # 4x4 ordering for kernel 8. Note it differs from the kernel-4 table and ignores `rate`.
    elif kernel_shape == 8:
        return [
            (1, 1),
            (1, 3),
            (3, 1),
            (3, 3),
            (1, 0),
            (1, 2),
            (3, 0),
            (3, 2),
            (0, 1),
            (0, 3),
            (2, 1),
            (2, 3),
            (0, 0),
            (0, 2),
            (2, 0),
            (2, 2),
        ]
    # 8x8 ordering (64 pairs) for kernel 16; `rate` is ignored here as well.
    elif kernel_shape == 16:
        return [
            (3, 3),
            (3, 7),
            (7, 3),
            (7, 7),
            (3, 1),
            (3, 5),
            (7, 1),
            (7, 5),
            (1, 3),
            (1, 7),
            (5, 3),
            (5, 7),
            (1, 1),
            (1, 5),
            (5, 1),
            (5, 5),
            (3, 2),
            (3, 6),
            (7, 2),
            (7, 6),
            (3, 0),
            (3, 4),
            (7, 0),
            (7, 4),
            (1, 2),
            (1, 6),
            (5, 2),
            (5, 6),
            (1, 0),
            (1, 4),
            (5, 0),
            (5, 4),
            (2, 3),
            (2, 7),
            (6, 3),
            (6, 7),
            (2, 1),
            (2, 5),
            (6, 1),
            (6, 5),
            (0, 3),
            (0, 7),
            (4, 3),
            (4, 7),
            (0, 1),
            (0, 5),
            (4, 1),
            (4, 5),
            (2, 2),
            (2, 6),
            (6, 2),
            (6, 6),
            (2, 0),
            (2, 4),
            (6, 0),
            (6, 4),
            (0, 2),
            (0, 6),
            (4, 2),
            (4, 6),
            (0, 0),
            (0, 4),
            (4, 0),
            (4, 4),
        ]
    else:
        # Default 2x2 row-major ordering for every other kernel/rate combination.
        return [(0, 0), (0, 1), (1, 0), (1, 1)]


def is_super_deconv(layer):
    """Return whether ``layer`` should be treated as a "super deconv".

    Decision order:

    1. The 16-bit precision modes (``a16_w16``, ``a16_w16_a16``, ``a16_w16_a8``) always
       return ``True``.
    2. Grouped layers return ``False``. A layer is grouped when ``layer.groups > 1`` or
       ``precision_config.quantization_groups`` is an integer greater than 1.
    3. Otherwise, return ``False`` for the (kernel_h, kernel_w, stride_h, stride_w)
       combinations listed below and ``True`` for everything else.

    NOTE(review): ``layer.strides[1]`` and ``layer.strides[2]`` are read as the
    height/width strides, which presumes an NHWC-style 4-element strides
    sequence. Confirm against the layer definition.
    """
    if layer.precision_config.precision_mode in [
        PrecisionMode.a16_w16,
        PrecisionMode.a16_w16_a16,
        PrecisionMode.a16_w16_a8,
    ]:
        return True

    if layer.groups > 1:
        return False

    quantization_groups = layer.precision_config.quantization_groups
    # Use isinstance, not `is int`: the identity test compared the value with the `int`
    # type object, was always False, and so silently disabled this guard.
    if isinstance(quantization_groups, (int, np.integer)) and quantization_groups > 1:
        return False

    kernel_and_strides = (layer.kernel_height, layer.kernel_width, layer.strides[1], layer.strides[2])
    # (kernel_h, kernel_w, stride_h, stride_w) combinations that are NOT super deconvs.
    regular_deconv_configs = (
        (4, 4, 2, 2),
        (4, 4, 4, 4),
        (2, 2, 2, 2),
        (8, 8, 4, 4),
        (16, 16, 8, 8),
    )
    return kernel_and_strides not in regular_deconv_configs
