#!/usr/bin/env python

import math

import numpy as np
from past.utils import old_div

from hailo_model_optimization.acceleras.utils.acceleras_definitions import BiasMode
from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import BackendQuantizationException
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.model_params.model_params import ModelParams

# Default upper bound on the "repeats" factor: default `maxsmallnum` of the
# *_smallnum_factorize helpers and the default cap in rep_as_uint_x_int_repeats.
MAX_NUM_REPEATS = 32
# Upper bound on repeats used by rep_as_uint_x_int_repeats when is_partial_numeric is set.
MAX_NUM_REPEATS_PARTIAL_NUMERIC = 8192
# 127 * 255 -- presumably max |int8| times max uint8 (not referenced in the code below).
MAX_I8_U8_MULT = 127 * 255

# 2**8 - 1 = 255
UINT8_MAX_VAL = 2**8 - 1
# 2**15 - 1 = 32767
INT16_MAX_VAL = 2**15 - 1

# Default `bits` argument of the factorization helpers in this module.
DEFAULT_BIT_SIZE = 8


def extract_limvals(stats_min, stats_max, sym_flag=False, prev_limvals=None):
    """
    Build a (low, high) pair of limit values out of min/max statistics.

    sym_flag=False: returns the tuple (stats_min, stats_max); when prev_limvals
        is given the range is widened so it also covers prev_limvals.
    sym_flag=True: returns a float32 array (-m, m) where m is the largest
        absolute value of the two statistics (used for kernels, which are
        assumed to have their zero point midway between min and max); when
        prev_limvals is given, m is additionally not smaller than prev_limvals[1].
    """
    has_prev = prev_limvals is not None

    if not sym_flag:
        if not has_prev:
            return (stats_min, stats_max)
        lower = np.min([stats_min, prev_limvals[0]])
        upper = np.max([stats_max, prev_limvals[1]])
        return (lower, upper)

    # symmetric range around zero
    peak = np.max([np.abs(stats_max), np.abs(stats_min)])
    if has_prev:
        peak = np.max([peak, prev_limvals[1]])
    return np.float32((-peak, peak))


def quantize_multiplier(mult, mantbits):
    """
    Quantize a real multiplier into an integer mantissa and a right shift,
    such that  mult ~= mult_q * 2**(-d_shift).

    The multiplier is split with np.frexp into mantissa * 2**exponent
    (for a positive mult, 0.5 <= mantissa < 1). The mantissa is scaled by
    2**mantbits and rounded UP (np.ceil), and the 2**mantbits factor is folded
    into the returned shift.

    If rounding up pushes the mantissa to exactly 2**mantbits, it is replaced
    by 2**(mantbits - 1) and the shift is reduced by one, which represents the
    same value.

    Returns:
    d_shift - mantbits minus the frexp exponent (minus 1 in the case above)
    mult_q - the integer-valued mantissa

    example:
    mult=0.4, mantbits=10 -> frexp gives mantissa 0.8, exponent -1
    mult_q = ceil(0.8 * 2**10) = 820, d_shift = 10 + 1 = 11
    check: 820 * 2**-11 = 0.4004
    """
    mantissa, exponent = np.frexp(mult)
    full_scale = 2**mantbits

    # Note: the 2**mantbits factor moves from the shift into the mantissa
    mult_q = np.ceil(mantissa * full_scale)
    d_shift = mantbits - exponent

    if mult_q == full_scale:
        # the mantissa rounded up to 1.0: halve it and shift one bit less
        return d_shift - 1, 2 ** (mantbits - 1)

    return d_shift, mult_q


def a_b_factorize(target, max_a, max_b):
    """
    find min-error factorization of a non-negative number as multiplication of
    two integers, one ranged up to max_a, the other ranged up to max_b.

    Every candidate b in [max(ceil(target / max_a), 1), max_b] is tried, the
    matching a is round(target / b), and the pair with the smallest
    |target - a * b| is returned (the smallest such b on ties).

    Returns (a, b) as numpy floating point scalars holding integer values.
    Raises AssertionError if target is negative or larger than max_a * max_b.
    """
    assert target >= 0
    assert target <= max_a * max_b, (
        f"can't represent this number, too big for max_a={max_a} * max_b={max_b}"
    )
    target = np.float32(target)
    # target is a float32 from here on, so "/" is always a true division
    smallest_b = max(np.ceil(np.abs(target) / max_a), 1)
    b_fac = np.arange(smallest_b, max_b + 1)
    a_fac = np.round(np.abs(target) / b_fac)
    error = np.abs(target - a_fac * b_fac)
    bind = np.argmin(error)
    return a_fac[bind], b_fac[bind]


def int_smallnum_factorize(target, bits=DEFAULT_BIT_SIZE, maxsmallnum=MAX_NUM_REPEATS):
    """
    find min-error factorization of a signed number into I*R, where
    |I| <= 2^(bits-1) - 1 and R is a positive integer not larger than maxsmallnum.

    The magnitude is factorized with a_b_factorize and the sign of target is
    put back on I. Returns (I, R).
    """
    int_limit = 2 ** (bits - 1) - 1
    magnitude_fac, repeat_fac = a_b_factorize(np.abs(target), int_limit, maxsmallnum)
    signed_fac = magnitude_fac * np.sign(target)
    return signed_fac, repeat_fac


def uint_smallnum_factorize(target, bits=DEFAULT_BIT_SIZE, maxsmallnum=MAX_NUM_REPEATS):
    """
    find min-error factorization of an unsigned number into U*R, where
    U <= 2^bits - 1 and R is a positive integer not larger than maxsmallnum.

    Returns (U, R) as produced by a_b_factorize.
    Raises BackendQuantizationException when target exceeds (2^bits - 1) * maxsmallnum.
    """
    max_a, max_b = 2**bits - 1, maxsmallnum
    largest_representable = max_a * max_b

    # explicit exception instead of relying on the assertion inside a_b_factorize
    if target > largest_representable:
        raise BackendQuantizationException(
            f"The number {target} can't be represented with {bits} bits and maximum {max_b} repeats.",
        )

    return a_b_factorize(target, max_a, max_b)


def uint_int_factorize(target, bits=DEFAULT_BIT_SIZE):
    """
    find min-error factorization of a signed number into U*I, where
    0 <= U <= 2^bits - 1 and 1 <= |I| <= 2^(bits-1) - 1 (I carries the sign of target).
    For bits=8 the representable range is |target| <= 127*255.

    Returns (U, I) as numpy floating point scalars holding integer values.
    Raises AssertionError when |target| is too big to be represented.
    """
    int_limit = 2 ** (bits - 1) - 1.0
    uint_limit = 2**bits - 1.0
    assert abs(target) <= int_limit * uint_limit, "can't represent this number, too big"

    target = np.float32(target)
    # a zero target still needs non-zero int candidates, so treat its sign as +1
    sign = 1 if target == 0 else np.sign(target)

    # smallest |I| for which U = |target / I| still fits in the uint range
    smallest_int = max(np.ceil(np.abs(target) / uint_limit), 1)
    int_candidates = np.arange(smallest_int, int_limit + 1) * sign
    # target is float32 and the candidates are an array, so "/" is a true division
    uint_candidates = np.round(np.abs(target / int_candidates))

    residual = np.abs(target - uint_candidates * int_candidates)
    best = np.argmin(residual)
    return uint_candidates[best], int_candidates[best]


def rep_as_uint_x_int_repeats(
    vector,
    bits=DEFAULT_BIT_SIZE,
    is_shared_uint_fac=True,
    is_partial_numeric=False,
    bias_mode=BiasMode.single_scale_decomposition,
    max_feed_repeat=None,
):
    """
    represent all members of vector as R*U*I, where R (repeats) is shared across the vector.
    TODO - add option for the <I> to be shared as well

    vector - values to represent (converted to float32)
    bits - bit width; I is limited to 2^(bits-1) - 1 and U to 2^bits - 1
    is_shared_uint_fac - when True, U is shared by all elements; when False
        every element gets its own U
    is_partial_numeric - selects MAX_NUM_REPEATS_PARTIAL_NUMERIC instead of
        MAX_NUM_REPEATS as the repeats cap when max_feed_repeat is not given
    bias_mode - BiasMode.double_scale_decomposition selects the two-factor
        decomposition (only looked at when is_shared_uint_fac is True)
    max_feed_repeat - explicit repeats cap (None or 0 means "use the default")

    Returns (repeats, vector_as_represented, int_vec, shared_u_fac):
    - shared U, single scale: int_vec is an array and shared_u_fac a scalar
    - shared U, double scale: int_vec is the tuple (int_vec_a, int_vec_b) and
      shared_u_fac the tuple (max_uint, 1)
    - non-shared U: int_vec and shared_u_fac are arrays with the per-element
      I and U factors
    """
    vector = np.float32(vector)
    if not max_feed_repeat:
        max_feed_repeat = MAX_NUM_REPEATS_PARTIAL_NUMERIC if is_partial_numeric else MAX_NUM_REPEATS
    repeats = 1
    shared_u_fac = 1
    max_int = 2 ** (bits - 1) - 1.0
    max_uint = 2**bits - 1.0
    max_int_uint_mul = max_int * max_uint

    # NOTE: vector is a float32 array, so every "/" below is a true division.
    if is_shared_uint_fac:
        # TODO - add option for "partially shared" R and/or U (across a few groups/clusters of channels)
        # the only degree of freedom is the I, as R*U are both shared and fixed.
        # so, for best utilization  I=127 will represent the highest element of vector,
        # and the R*U calculated accordingly
        if bias_mode == BiasMode.double_scale_decomposition:
            shared_u_fac = (max_uint, 1)
            repeats = np.ceil(np.max(np.abs(vector)) / ((shared_u_fac[0] + shared_u_fac[1]) * max_int))
            if repeats == 0:
                # all-zero vector: keep a single repeat instead of dividing by zero (NaN results)
                repeats = 1
            int_vec_a = np.floor(vector / repeats / shared_u_fac[0])
            int_vec_b = (vector / repeats - int_vec_a * shared_u_fac[0]) / shared_u_fac[1]
            int_vec = (int_vec_a, int_vec_b)
            # NOTE(review): int_vec_b is not rounded and vector_as_represented is at the
            # vector / repeats scale here (it is not multiplied by repeats as in the other
            # branches) - confirm this is what the callers expect.
            vector_as_represented = shared_u_fac[0] * int_vec_a + shared_u_fac[1] * int_vec_b

        # in case all the element in the bias are the same, we just can use a standard factorization
        elif np.all(vector == vector[0]):
            if vector[0] == 0:
                shared_u_fac = 0
                vector_as_represented = int_vec = vector
            else:
                repeats = np.ceil(np.abs(vector[0]) / max_int_uint_mul)
                vector_2repeat = np.round(vector / repeats)
                shared_u_fac, int_val = uint_int_factorize(vector_2repeat[0], bits=bits)
                int_vec = np.float32(np.array(len(vector) * [int_val]))
                vector_as_represented = int_vec * shared_u_fac * repeats
        else:
            u_x_r = np.ceil(np.max(np.abs(vector)) / max_int)
            # at least we'll try to precisely represent the total factor as repeats X fixed_factor
            if u_x_r != 0:
                shared_u_fac, repeats = uint_smallnum_factorize(u_x_r, bits=bits, maxsmallnum=max_feed_repeat)
                int_vec = np.round(vector / (shared_u_fac * repeats * 1.0))
                vector_as_represented = int_vec * (shared_u_fac * repeats * 1.0)
            else:
                vector_as_represented = int_vec = np.round(vector)
    else:
        repeats = np.ceil(np.max(np.abs(vector)) / max_int_uint_mul)
        if repeats == 0:
            # all-zero vector: keep a single repeat instead of dividing by zero
            repeats = 1
        vector_2repeat = np.round(vector / repeats)
        vector_2repeat_factorized = [uint_int_factorize(t, bits=bits) for t in vector_2repeat]
        vector_as_represented = repeats * np.array([u_fac * i_fac for u_fac, i_fac in vector_2repeat_factorized])
        # int_vec used to be left undefined on this path, which made the return below
        # raise UnboundLocalError; expose the per-element I and U factors instead.
        int_vec = np.float32(np.array([i_fac for _, i_fac in vector_2repeat_factorized]))
        shared_u_fac = np.float32(np.array([u_fac for u_fac, _ in vector_2repeat_factorized]))

    return repeats, vector_as_represented, int_vec, shared_u_fac


def _bilinear_axis_residuals(n_in, n_out, half_pixels, align_corners):
    """
    Fractional interpolation offsets along one spatial axis, as a 1-D array of
    length n_out. Coordinates are handled on an integer grid scaled by `factor`.
    """
    ratio = n_out / n_in
    if not align_corners or n_out == 1:
        factor = 2 * n_out
    else:
        factor = 2 * (n_out - 1)

    # first / last sampling coordinate on the scaled grid
    if half_pixels:
        first = round((0.5 + 0.5 / ratio) * factor)
        last = round((n_in + 0.5 - 0.5 / ratio) * factor)
    elif align_corners:
        first = factor
        last = n_in * factor
    else:
        first = factor
        last = round((n_in + 1 - (1 / ratio)) * factor)

    grid = np.linspace(first, last, n_out)
    # residual = fractional part of the sampling coordinate
    residual = grid / factor - grid // factor

    if half_pixels:
        # number of output samples at each border that get a forced residual
        if int(ratio) == ratio:
            corner = int(ratio / 2)
        elif (last - first) != 0:
            corner = math.ceil((factor - first) / (factor / ratio))
        else:
            corner = 0

        if ratio % 2 != 0 and ratio == int(ratio):
            # odd integer ratio: every int(ratio)-th sample starting at `corner` is zeroed
            residual[corner :: int(ratio)] = 0
        if corner != 0:
            residual[: int(corner)] = 1
            residual[int(-(corner)) :] = 0
    elif not align_corners:
        residual[int(-(math.floor(ratio))) :] = 0

    return residual


def get_bilinear_weights(name, inp, qp_in):
    """
    Build the per-output-pixel weights of a bilinear resize.

    inp is indexed with "resize_bilinear_pixel_mode" ("half_pixels",
    "align_corners" or anything else for the default mode), "input_tensor" and
    "output_tensor"; shape[1] / shape[2] of the tensors are read as height /
    width (presumably NHWC - confirm against the caller).
    name and qp_in are not used.

    Returns (bias, kernel): bias is np.array([0]) and kernel has the shape
    (h_out, w_out, 4, 1), holding along axis 2 the weights
    (1-rx)(1-ry), rx(1-ry), (1-rx)ry, rx*ry, where rx / ry are the horizontal /
    vertical residuals of the output pixel.
    """
    pixel_mode = inp["resize_bilinear_pixel_mode"]
    half_pixels = pixel_mode == "half_pixels"
    align_corners = pixel_mode == "align_corners"

    in_shape = inp["input_tensor"].shape
    out_shape = inp["output_tensor"].shape

    res_rows = _bilinear_axis_residuals(in_shape[1], out_shape[1], half_pixels, align_corners)
    res_cols = _bilinear_axis_residuals(in_shape[2], out_shape[2], half_pixels, align_corners)

    # lay the two axes out so that their products broadcast to (h_out, w_out, 1, 1)
    r_y = res_rows[:, np.newaxis, np.newaxis, np.newaxis]
    r_x = res_cols[np.newaxis, :, np.newaxis, np.newaxis]
    inv_y = 1 - r_y
    inv_x = 1 - r_x

    corner_weights = (inv_x * inv_y, r_x * inv_y, inv_x * r_y, r_x * r_y)
    kernel = np.concatenate(corner_weights, axis=2)
    bias = np.array([0])
    return bias, kernel


def get_weights(name, params, layer_type):
    """
    Fetch the (bias, kernel) pair of layer `name` from `params`.

    params is wrapped in ModelParams when it is not one already.
    - LayerType.bbox_decoder: the kernel is an array stacking the six anchor
      entries of the layer and the bias is the list [y_centers, x_centers].
    - any other layer type: the "bias" and "kernel" entries are returned.

    If the layer has a "weights_clipping_values" entry, its first element is
    used as the lower and its second as the upper clipping bound of the kernel.
    The values get an extra axis first: the last axis for dw / normalization
    layers, axis 1 otherwise.
    """
    if not isinstance(params, ModelParams):
        params = ModelParams(params)

    if layer_type != LayerType.bbox_decoder:
        bias = params[name]["bias"]
        kernel = params[name]["kernel"]
    else:
        anchor_keys = (
            "anchors_heights",
            "anchors_widths",
            "anchors_heights_div_2",
            "anchors_widths_div_2",
            "anchors_heights_minus_div_2",
            "anchors_widths_minus_div_2",
        )
        kernel = np.array([params[name][key] for key in anchor_keys])
        bias = [params[name]["y_centers"], params[name]["x_centers"]]

    if "weights_clipping_values" not in params[name]:
        return bias, kernel

    expand_axis = -1 if layer_type in {LayerType.dw, LayerType.normalization} else 1
    clip_bounds = np.expand_dims(np.array(params[name]["weights_clipping_values"]), axis=expand_axis)
    kernel = np.clip(kernel, clip_bounds[0], clip_bounds[1])
    return bias, kernel
