import numpy as np
from scipy.linalg import block_diag, hadamard
from scipy.stats import ortho_group

from hailo_model_optimization.acceleras.utils.acceleras_definitions import OrthoGenType


def _get_hadamard_rotation_matrix(size: int, seed: int = 1234) -> np.ndarray:
    """
    Generate orthonormal matrix using Hadamard matrix of size `size`.
    """
    return hadamard(size, dtype=np.float32) / np.sqrt(size)


def _get_random_rotation_matrix(size: int, seed: int = 1234) -> np.ndarray:
    """
    Generate a random orthonormal matrix of size `size`.
    """
    return ortho_group.rvs(size, random_state=seed)


def _get_partial_random_rotation_matrix(size: int, seed: int = 1234) -> np.ndarray:
    """
    Generate a random orthonormal matrix of size `size` such that the first channel represent the sum of the
    original vector divided by sqrt(size).

    Examples
    --------
    >>> import numpy as np
    >>> A = np.array([[2, 4, -1], [-3, 5, -1]])
    >>> Q = _get_partial_random_rotation_matrix(3)
    >>> Q
    array([[ 0.57735027, -0.32432726, -0.74931869],
            [ 0.57735027,  0.81109265,  0.09378369],
            [ 0.57735027, -0.48676539,  0.65553499]])

    >>> (A @ Q)[:, 0]
    array([2.88675135, 0.57735027])

    >>> np.sum(A, axis=-1) / np.sqrt(3)
    array([2.88675135, 0.57735027])
    """
    projection = np.block(
        [
            [np.ones((1, size)) / np.sqrt(size)],
            [np.ones((size - 1, 1)) / np.sqrt(size), np.eye(size - 1) - 1 / (size - np.sqrt(size))],
        ]
    )
    rotation = block_diag([[1]], _get_random_rotation_matrix(size - 1, seed))
    return projection @ rotation


def get_rotation_matrix(
    size: int,
    *args,
    ortho_gen_type: OrthoGenType = OrthoGenType.PARTIAL_RANDOM,
    groups: int = 1,
    seed: int = 1234,
) -> np.ndarray:
    """
    Generate an orthonormal matrix of shape `(size, size)`.

    The result is block diagonal: `groups` identical copies of one `(size // groups, size // groups)`
    orthonormal block produced by the generator selected with `ortho_gen_type`.

    Args:
        size (int): The dimension of the matrix.
        *args: Ignored; extra positional arguments are accepted and discarded.
        ortho_gen_type (OrthoGenType, optional): Generation type of the matrix. Defaults to OrthoGenType.PARTIAL_RANDOM.
        groups (int, optional): Number of identical diagonal blocks. Must be positive and divide `size`.
            Defaults to 1.
        seed (int, optional): The seed to generate the matrix with. Defaults to 1234.

    Returns:
        np.ndarray: orthonormal matrix of shape `(size, size)`.

    Raises:
        KeyError: If `ortho_gen_type` has no registered generator.
        ValueError: If `groups` is not positive or does not divide `size`.
    """
    ortho_gen_callback = {
        OrthoGenType.HADAMARD: _get_hadamard_rotation_matrix,
        OrthoGenType.RANDOM: _get_random_rotation_matrix,
        OrthoGenType.PARTIAL_RANDOM: _get_partial_random_rotation_matrix,
    }[ortho_gen_type]
    # Without this check a non-divisible `size` silently yields a matrix smaller than (size, size).
    if groups < 1 or size % groups != 0:
        raise ValueError(f"size ({size}) must be divisible by a positive number of groups, got groups={groups}")
    block_size = size // groups
    block = ortho_gen_callback(block_size, seed=seed)
    # Every group reuses the same block (same seed), so all groups are rotated identically.
    return block_diag(*([block] * groups))
