from functools import partial

import numpy as np
import tensorflow as tf
from pydantic import BaseModel, Field

from hailo_model_optimization.acceleras.utils.acceleras_definitions import FormatConversionType


class ConvertionWeightsDataStruct(BaseModel):
    """Container for the optional weights used by the parameterized format conversions.

    Every field defaults to ``None``. ``get_conversion_callback`` reads all of
    the fields and binds each one to the conversion function that consumes it,
    so only the fields relevant to the selected ``FormatConversionType`` carry
    meaningful values.
    """

    # FormatConversionType.rotation
    # Used as the right-hand operand of a matrix multiplication with the inputs.
    rotation: np.ndarray = Field(None, description="Rotation matrix for rotation format convertion")

    # FormatConversionType.mask, FormatConversionType.cos, FormatConversionType.sin
    # Indexed as tile[-2] / tile[-1] by the generators and prefixed with 1 to
    # form the ``tf.tile`` multiples of the generated tensor.
    tile: np.ndarray = Field(None, description="Tile count for the generated input. Will be used for mask, cos and sin")

    # FormatConversionType.cos, FormatConversionType.sin
    # ``factor`` scales the cos/sin output; ``theta`` is multiplied by the time step.
    factor: float = Field(None, description="Factor to multiply the generated input with. Will be used for cos and sin")
    theta: np.ndarray = Field(None, description="Angle per feature for cos and sin")

    # FormatConversionType.embedding
    # Lookup table gathered along axis 0 using the (integer-cast) inputs.
    embed: np.ndarray = Field(None, description="Embedding matrix for embedding format convertion")

    class Config:
        # np.ndarray is not a pydantic-native type, so it must be explicitly allowed.
        arbitrary_types_allowed = True
        # Reject any field name that is not declared above.
        extra = "forbid"


def get_conversion_callback(conversion_type: FormatConversionType, conversion_weights: ConvertionWeightsDataStruct):
    """Return the conversion callable registered for ``conversion_type``.

    The returned callable is invoked as ``callback(inputs, input_shapes)``;
    any extra weights a conversion needs are pre-bound with ``functools.partial``.

    Args:
        conversion_type: The format conversion to look up.
        conversion_weights: Source of the weights bound into the rotation,
            mask, cos, sin and embedding callbacks. All of its fields are read
            on every call, regardless of ``conversion_type``.

    Returns:
        The callable implementing the requested conversion.

    Raises:
        KeyError: If ``conversion_type`` has no registered conversion.
    """

    def _trig_generator(trig_fn):
        # cos and sin share every bound argument except the function itself.
        return partial(
            cos_sin_generation,
            function=trig_fn,
            tile=conversion_weights.tile,
            factor=conversion_weights.factor,
            theta=conversion_weights.theta,
        )

    # Conversions that need nothing besides (inputs, input_shapes).
    callbacks = {
        FormatConversionType.yuy2_to_hailo_yuv: yuy2_to_yuv_conversion,
        FormatConversionType.tf_rgbx_to_hailo_rgb: rgbx_to_rgb_conversion,
        FormatConversionType.nv12_to_hailo_yuv: nv12_to_yuv_conversion,
        FormatConversionType.nv21_to_hailo_yuv: nv21_to_yuv_conversion,
        FormatConversionType.i420_to_hailo_yuv: i420_to_yuv_conversion,
    }

    # Conversions parameterized by the supplied weights.
    callbacks[FormatConversionType.rotation] = partial(
        rotation_conversion, rotation_matrix=conversion_weights.rotation
    )
    callbacks[FormatConversionType.mask] = partial(mask_generation, tile=conversion_weights.tile)
    callbacks[FormatConversionType.cos] = _trig_generator(tf.cos)
    callbacks[FormatConversionType.sin] = _trig_generator(tf.sin)
    callbacks[FormatConversionType.embedding] = partial(embedding_conversion, embed=conversion_weights.embed)

    return callbacks[conversion_type]


def yuy2_to_yuv_conversion(inputs: tf.Tensor, input_shapes) -> tf.Tensor:
    """Unpack a YUY2-packed tensor into a channels-last YUV tensor.

    The input is read as consecutive groups of four values ``[Y0, U, Y1, V]``.
    Each U and V value is repeated twice so that both Y samples of a group get
    a chroma value.

    Args:
        inputs: Packed tensor. Its static dimensions 1 and 2 must be known;
            half of their product is the number of four-value groups.
        input_shapes: Shape whose entries 1 and 2 are used as the two spatial
            dimensions of the output.

    Returns:
        Tensor of shape ``(-1, input_shapes[1], input_shapes[2], 3)`` with the
        last axis ordered Y, U, V.
    """
    group_count = (inputs.shape[1] * inputs.shape[2]) // 2
    groups = tf.reshape(inputs, (-1, group_count, 4))

    # Chroma samples are shared by the two luma samples of a group.
    u_plane = tf.repeat(groups[:, :, 1], 2, axis=1)
    v_plane = tf.repeat(groups[:, :, 3], 2, axis=1)

    # Interleave Y0/Y1 back into their original order.
    luma_pairs = tf.stack([groups[:, :, 0], groups[:, :, 2]], axis=2)
    y_plane = tf.reshape(luma_pairs, [-1, groups.shape[1] * 2])

    # Build the planes channels-first, then move the channel axis last.
    planar = tf.stack((y_plane, u_plane, v_plane), axis=1)
    planar = tf.reshape(planar, (-1, 3, input_shapes[1], input_shapes[2]))
    return tf.transpose(planar, [0, 2, 3, 1])


def rgbx_to_rgb_conversion(inputs: tf.Tensor, input_shapes) -> tf.Tensor:
    """Keep only the first three entries along the fourth axis of ``inputs``.

    Args:
        inputs: Tensor indexed with four axes; the fourth is sliced to ``[:3]``.
        input_shapes: Unused here; accepted so the function matches the
            ``(inputs, input_shapes)`` callback signature.

    Returns:
        ``inputs`` with the fourth axis truncated to its first three entries.
    """
    whole_axis = slice(None)
    first_three = slice(None, 3)
    return inputs[whole_axis, whole_axis, whole_axis, first_three]


def nv12_to_yuv_conversion(inputs: tf.Tensor, input_shapes) -> tf.Tensor:
    """Convert NV12 input to YUV by delegating to ``nv_to_yuv_conversion``."""
    layout = FormatConversionType.nv12_to_hailo_yuv
    return nv_to_yuv_conversion(inputs, input_shapes, layout)


def nv21_to_yuv_conversion(inputs: tf.Tensor, input_shapes) -> tf.Tensor:
    """Convert NV21 input to YUV by delegating to ``nv_to_yuv_conversion``."""
    layout = FormatConversionType.nv21_to_hailo_yuv
    return nv_to_yuv_conversion(inputs, input_shapes, layout)


def i420_to_yuv_conversion(inputs: tf.Tensor, input_shapes) -> tf.Tensor:
    """Convert I420 input to YUV by delegating to ``nv_to_yuv_conversion``."""
    layout = FormatConversionType.i420_to_hailo_yuv
    return nv_to_yuv_conversion(inputs, input_shapes, layout)


def nv_to_yuv_conversion(inputs: tf.Tensor, input_shapes, conversion_type) -> tf.Tensor:
    """Convert an NV12 / NV21 / I420 tensor into a channels-last YUV tensor.

    The input is split into three equally sized streams of
    ``input_shapes[1] * input_shapes[2] // 2`` values each: the first two hold
    the Y plane and the third holds the chroma samples. Chroma samples are
    duplicated along the row and reused for two consecutive rows.

    Args:
        inputs: Tensor holding the packed planes; it is reshaped to
            ``(-1, 3, input_shapes[1] * input_shapes[2] // 2)``.
        input_shapes: Shape whose entries 1 and 2 are used as the row count
            and row length of the Y plane.
        conversion_type: One of ``FormatConversionType.nv12_to_hailo_yuv``,
            ``nv21_to_hailo_yuv`` or ``i420_to_hailo_yuv``; selects how U and V
            are laid out inside the chroma stream.

    Returns:
        Tensor with a last axis of size 3 (Y, U, V), rebuilt two rows at a time
        and concatenated along axis 1.

    Raises:
        ValueError: If ``conversion_type`` is not one of the supported layouts.
    """
    # Number of values in each of the three streams (also the size of two Y rows' worth of pairs).
    stream_size = (input_shapes[1] * input_shapes[2]) // 2

    three_streams = tf.reshape(inputs, (-1, 3, stream_size))
    y1 = three_streams[:, 0]
    y2 = three_streams[:, 1]
    uv = three_streams[:, 2]

    y_channel = tf.reshape(tf.concat([y1, y2], axis=1), (-1, input_shapes[1], input_shapes[2]))
    even_y_rows = tf.reshape(y_channel[:, 0::2], [-1, stream_size])
    odd_y_rows = tf.reshape(y_channel[:, 1::2], [-1, stream_size])
    two_rows_y_channel = tf.expand_dims(tf.stack((even_y_rows, odd_y_rows), axis=1), axis=3)

    # The formats differ in the u, v order
    if conversion_type == FormatConversionType.nv12_to_hailo_yuv:
        # Interleaved U, V pairs.
        raw_u_v = tf.stack([uv[:, ::2], uv[:, 1::2]], axis=2)
    elif conversion_type == FormatConversionType.nv21_to_hailo_yuv:
        # Interleaved V, U pairs.
        raw_u_v = tf.stack([uv[:, 1::2], uv[:, ::2]], axis=2)
    elif conversion_type == FormatConversionType.i420_to_hailo_yuv:
        # First half of the stream is U, second half is V.
        raw_u_v = tf.stack([uv[:, : uv.shape[1] // 2], uv[:, uv.shape[1] // 2 :]], axis=2)
    else:
        # Fail explicitly instead of hitting an unbound ``raw_u_v`` below.
        raise ValueError(f"Unsupported conversion type for NV/I420 to YUV conversion: {conversion_type}")

    # Duplicate every chroma sample along the row, then share it between two rows.
    u_v = tf.repeat(raw_u_v, 2, axis=1)
    u_v = tf.expand_dims(u_v, axis=1)
    two_rows_u_v_channels = tf.repeat(u_v, 2, axis=1)
    two_rows_all_channels = tf.concat([two_rows_y_channel, two_rows_u_v_channels], axis=3)
    # Cut the flattened data into row-length chunks and stack the row pairs vertically.
    two_rows_chunks = tf.split(two_rows_all_channels, two_rows_all_channels.shape[2] // input_shapes[2], axis=2)
    return tf.concat(two_rows_chunks, axis=1)


def rotation_conversion(inputs: tf.Tensor, input_shapes, rotation_matrix) -> tf.Tensor:
    """Right-multiply ``inputs`` by ``rotation_matrix``.

    Args:
        inputs: Tensor used as the left operand of the matrix product.
        input_shapes: Unused here; accepted so the function matches the
            ``(inputs, input_shapes)`` callback signature.
        rotation_matrix: Matrix converted to a constant of ``inputs.dtype``.

    Returns:
        The matrix product ``inputs @ rotation_matrix``.
    """
    rotation = tf.constant(rotation_matrix, dtype=inputs.dtype)
    return inputs @ rotation


def mask_generation(inputs: tf.Tensor, input_shapes, tile) -> tf.Tensor:
    """Generate a tiled float mask from a per-sample integer value.

    The untiled mask has shape ``(batch, 1, rows, cols)`` with
    ``rows = input_shapes[-2] // tile[-2]`` and ``cols = input_shapes[-1] // tile[-1]``.
    An entry is 1.0 when either:

    * it lies on or below the diagonal of ``np.tri`` offset by ``cols - rows``
      AND its column index is at least ``cols - value``; or
    * its row index is smaller than ``rows - value``.

    Args:
        inputs: Tensor with one value per sample (cast to int32 and flattened
            to the batch axis). Presumably the number of valid tokens per
            sample; confirm against the caller.
        input_shapes: Shape of the tiled output; only the last two entries are read.
        tile: Tile counts per non-batch axis; a leading 1 is prepended for the
            batch axis.

    Returns:
        float32 tensor: the untiled mask repeated according to ``[1, *tile]``.
    """
    inputs = tf.cast(inputs, tf.int32)
    per_sample = tf.reshape(inputs, [-1, 1, 1, 1])

    row_count = input_shapes[-2] // tile[-2]
    col_count = input_shapes[-1] // tile[-1]

    row_index = tf.reshape(tf.range(row_count, dtype=inputs.dtype), [1, 1, -1, 1])
    col_index = tf.reshape(tf.range(col_count, dtype=inputs.dtype), [1, 1, 1, -1])

    triangular = np.tri(row_count, col_count, col_count - row_count, dtype=bool)
    attention_mask = tf.reshape(triangular, [1, 1, row_count, col_count])
    history_mask = col_count - per_sample <= col_index
    padding_mask = row_count - per_sample > row_index

    combined = tf.logical_or(tf.logical_and(attention_mask, history_mask), padding_mask)
    multiples = tf.concat([[1], tile], axis=0)
    return tf.tile(tf.cast(combined, tf.float32), multiples)


def cos_sin_generation(inputs: tf.Tensor, input_shapes, function, tile, factor, theta) -> tf.Tensor:
    """Generate a tiled ``factor * function(theta * time_step)`` tensor.

    With ``steps = input_shapes[-2] // tile[-2]``, the time step for a sample
    is its input value minus each of ``steps, steps - 1, ..., 1``. All
    arithmetic is done in float64.

    Args:
        inputs: Tensor with one value per sample (flattened to the batch axis).
        input_shapes: Shape of the tiled output; only the last two entries are read.
        function: Element-wise callable applied to the angle (``tf.cos`` / ``tf.sin``
            when built by ``get_conversion_callback``).
        tile: Tile counts per non-batch axis; a leading 1 is prepended for the
            batch axis.
        factor: Scalar multiplied with the result of ``function``.
        theta: Angles, reshaped so the last axis has
            ``input_shapes[-1] // tile[-1]`` entries.

    Returns:
        float64 tensor: the untiled result repeated according to ``[1, *tile]``.
    """
    inputs = tf.cast(inputs, tf.float64)
    compute_dtype = inputs.dtype

    feature_count = input_shapes[-1] // tile[-1]
    step_count = input_shapes[-2] // tile[-2]

    angles = tf.cast(tf.reshape(theta, [-1, 1, 1, feature_count]), compute_dtype)
    # Descending offsets: step_count, step_count - 1, ..., 1.
    countdown = tf.reshape(tf.range(step_count, 0, -1, dtype=compute_dtype), [1, 1, -1, 1])
    time_step = tf.reshape(inputs, [-1, 1, 1, 1]) - countdown

    untiled = factor * function(angles * time_step)
    multiples = tf.concat([[1], tile], axis=0)
    return tf.tile(untiled, multiples)


def embedding_conversion(inputs: tf.Tensor, input_shapes, embed) -> tf.Tensor:
    """Look up rows of ``embed`` using ``inputs`` as indices.

    Args:
        inputs: Index tensor; cast to int32 before the lookup.
        input_shapes: Unused here; accepted so the function matches the
            ``(inputs, input_shapes)`` callback signature.
        embed: Lookup table, gathered along axis 0.

    Returns:
        The gathered slices of ``embed``.
    """
    table = tf.constant(embed)
    indices = tf.cast(inputs, dtype=tf.int32)
    return tf.gather(table, indices, axis=0)
