#!/usr/bin/env python

from enum import Enum

import numpy as np


class InvalidDatasetException(Exception):
    """Raised when an image dataset handed to this module does not have the expected layout."""


class ColorType(Enum):
    """Named color formats used when translating an RGB image dataset.

    Every member's value is the member name as a lowercase string, so a member can
    also be looked up from text, e.g. ``ColorType("yuv709")``.
    """

    # YUV variants (distinct members even where the conversion code treats them alike).
    yuv = "yuv"
    yuv_full_range = "yuv_full_range"
    yuv601 = "yuv601"
    yuv709 = "yuv709"
    # Channel orderings.
    rgb = "rgb"
    bgr = "bgr"


def translate_rgb_dataset(rgb_dataset, color_type=ColorType.yuv):
    """
    Translate a given RGB format images dataset to YUV, BGR or RGB format images. This function is
    useful when the model expects YUV or BGR images, while the calibration images used for
    quantization are in RGB.

    Args:
        rgb_dataset (``numpy.ndarray``): Numpy array (or array-like) of RGB format images with
            shape ``(image_count, h, w, 3)`` to translate.
        color_type (:class:`~ColorType` or ``str``): type of color to translate the data to, given
            either as a :class:`~ColorType` member or its string value (e.g. ``"yuv709"``).
            Defaults to yuv.

    Returns:
        ``numpy.ndarray``: A new ``float64`` array with the same shape as ``rgb_dataset`` holding
        the translated images. For ``ColorType.rgb`` this is an unchanged float64 copy.

    Raises:
        InvalidDatasetException: If the dataset is not 4-dimensional with 3 channels on the last
            axis.
        ValueError: If ``color_type`` is not a valid :class:`~ColorType` member or value.
    """
    rgb_dataset = np.asarray(rgb_dataset)
    # Require exactly (image_count, h, w, 3): the conversion below contracts the last axis, so
    # extra trailing dimensions would silently transform the wrong axis.
    if rgb_dataset.ndim != 4 or rgb_dataset.shape[3] != 3:
        raise InvalidDatasetException("The given dataset must be in RGB format (3 features)")

    # Accept either a ColorType member or its string value; anything else raises ValueError here
    # instead of failing later on an unassigned conversion matrix.
    color_type = ColorType(color_type)

    if color_type == ColorType.rgb:
        # Already in the requested layout: return an independent float64 copy.
        return rgb_dataset.astype(np.float64)

    if color_type == ColorType.bgr:
        # BGR is a pure channel permutation. Reversing the channel axis copies values exactly,
        # whereas a 0/1 permutation matrix would turn inf pixels into nan (inf * 0).
        return rgb_dataset[..., ::-1].astype(np.float64)

    # Each YUV variant maps to (transition_matrix, offset). Pixels are row vectors, so the output
    # is ``[r, g, b] @ transition_matrix + offset``: row i of the matrix is the contribution of
    # input channel i to each of the three output channels.
    yuv601_matrix = [
        [0.2568619, -0.14823364, 0.43923104],
        [0.5042455, -0.2909974, -0.367758],
        [0.09799913, 0.43923104, -0.07147305],
    ]
    yuv_conversions = {
        ColorType.yuv_full_range: (
            [
                [0.299, -0.169, 0.5],
                [0.587, -0.331, -0.419],
                [0.114, 0.5, -0.081],
            ],
            [0, 128, 128],  # no luma offset for full range
        ),
        ColorType.yuv: (yuv601_matrix, [16, 128, 128]),
        ColorType.yuv601: (yuv601_matrix, [16, 128, 128]),
        ColorType.yuv709: (
            [
                [0.183, -0.101, 0.439],
                [0.614, -0.339, -0.399],
                [0.062, 0.439, -0.040],
            ],
            [16, 128, 128],
        ),
    }
    transition_matrix, offset = yuv_conversions[color_type]

    # np.dot contracts the last (channel) axis of the (image_count, h, w, 3) dataset with the
    # matrix rows, converting every pixel of every image in a single vectorised call.
    translated_dataset = np.dot(rgb_dataset, np.array(transition_matrix)) + np.array(offset)
    return translated_dataset.astype(np.float64, copy=False)
