import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ColorConversionType,
    FormatConversionType,
    HybridConversionType,
)
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, PaddingType
from hailo_sdk_common.hailo_nn.hn_layers import FormatConversionLayer, FusedConv2DLayer
from hailo_sdk_common.hailo_nn.hn_layers.layer_common import MODEL_SCRIPT_LAYER_PREFIX
from hailo_sdk_common.logger.logger import default_logger


class InputConversionLayersAdditionException(Exception):
    """Raised when input conversion layers cannot be added (invalid inputs, layer names or input shapes)."""


class InputConversionLayersAdder:
    """
    Adds conversion layers in front of input layers of a Hailo NN.

    Depending on ``conversion_type`` this adds, for every input layer to translate:

    * a color conversion layer (a 1x1 convolution whose kernel and bias are written into ``params``),
    * a format conversion layer, or
    * for hybrid conversions, a format conversion layer to Hailo YUV plus a YUV to RGB color conversion layer.

    Args:
        hailo_nn (:class:`~hailo_sdk_common.hailo_nn.hailo_nn.HailoNN`): The Hailo NN to add the conversion
            layers to.
        params (:class:`~hailo_sdk_common.model_params.model_params.ModelParams`): The params to which the
            conversion layers' params are added.
        conversion_type (:class:`~hailo_model_optimization.acceleras.utils.acceleras_definitions.ColorConversionType`,
            :class:`~hailo_model_optimization.acceleras.utils.acceleras_definitions.FormatConversionType` or
            :class:`~hailo_model_optimization.acceleras.utils.acceleras_definitions.HybridConversionType`): The
            conversion to apply to the inputs.
        input_names_to_translate (list of str, optional): Names of the input layers to add conversion layers to,
            defaults to all input layers.
        conversion_layers (list of str, optional): Names for the new layers. Hybrid conversions need two names
            per input (format conversion layer first, then color conversion layer). Defaults to names generated
            from the conversion type.
        in_emulation_graph (bool): Value assigned to ``in_emulation_graph`` of the new layers (the color
            conversion layer of a hybrid conversion always gets ``True``).

    """

    HYBRID_CONVERSION_TO_FORMAT_CONVERSION = {
        HybridConversionType.nv12_to_rgb: FormatConversionType.nv12_to_hailo_yuv,
        HybridConversionType.nv21_to_rgb: FormatConversionType.nv21_to_hailo_yuv,
        HybridConversionType.yuy2_to_rgb: FormatConversionType.yuy2_to_hailo_yuv,
        HybridConversionType.i420_to_rgb: FormatConversionType.i420_to_hailo_yuv,
    }
    YUV2_TO_HAILO_YUV_OUTPUT_WIDTH_FACTOR = 16

    # (kernel, bias) of each supported color conversion. Kernels are 3x3 matrices whose rows index the input
    # channels and whose columns index the output channels; they are wrapped to the [1, 1, 3, 3] layout of
    # ``kernel_shape`` when written to the params. BGR variants are the RGB ones with the output columns and
    # the bias reversed. Tuples are used so the shared entries cannot be mutated by accident.
    _YUV_FULL_RANGE_TO_RGB = (
        ((1, 1, 1), (0, -0.343, 1.765), (1.4, -0.711, 0)),
        (-179.2, 134.912, -225.92),
    )
    _YUV601_TO_RGB = (
        ((1.164, 1.164, 1.164), (0, -0.392, 2.017), (1.596, -0.813, 0)),
        (-222.912, 135.616, -276.8),
    )
    _YUV709_TO_RGB = (
        ((1.164, 1.164, 1.164), (0, -0.213, 2.112), (1.793, -0.533, 0)),
        (-248.128, 76.864, -288.96),
    )
    _RGB_BGR_SWAP = (
        ((0, 0, 1), (0, 1, 0), (1, 0, 0)),
        (0, 0, 0),
    )
    _YUV_FULL_RANGE_TO_BGR = (
        ((1, 1, 1), (1.765, -0.343, 0), (0, -0.711, 1.4)),
        (-225.92, 134.912, -179.2),
    )
    _YUV601_TO_BGR = (
        ((1.164, 1.164, 1.164), (2.017, -0.392, 0), (0, -0.813, 1.596)),
        (-276.8, 135.616, -222.912),
    )
    _YUV709_TO_BGR = (
        ((1.164, 1.164, 1.164), (2.112, -0.213, 0), (0, -0.533, 1.793)),
        (-288.96, 76.864, -248.128),
    )
    _COLOR_CONVERSION_PARAMS = {
        ColorConversionType.yuv_full_range_to_rgb: _YUV_FULL_RANGE_TO_RGB,
        ColorConversionType.yuv_to_rgb: _YUV601_TO_RGB,
        ColorConversionType.yuv601_to_rgb: _YUV601_TO_RGB,
        ColorConversionType.yuv709_to_rgb: _YUV709_TO_RGB,
        ColorConversionType.bgr_to_rgb: _RGB_BGR_SWAP,
        ColorConversionType.rgb_to_bgr: _RGB_BGR_SWAP,
        ColorConversionType.yuv_full_range_to_bgr: _YUV_FULL_RANGE_TO_BGR,
        ColorConversionType.yuv_to_bgr: _YUV601_TO_BGR,
        ColorConversionType.yuv601_to_bgr: _YUV601_TO_BGR,
        ColorConversionType.yuv709_to_bgr: _YUV709_TO_BGR,
    }

    def __init__(
        self,
        hailo_nn,
        params,
        conversion_type,
        input_names_to_translate,
        conversion_layers,
        in_emulation_graph,
    ):
        self._hailo_nn = hailo_nn
        self._params = params
        self._input_layers_to_translate = self._get_input_names_to_translate(input_names_to_translate)
        self._logger = default_logger()
        self._conversion_type = conversion_type
        # Sets self._conversion_layers (needs the input layers and conversion type set above).
        self._get_conversion_names(conversion_layers)
        self._in_emulation_graph = in_emulation_graph

    @property
    def input_layers_to_translate(self):
        return self._input_layers_to_translate

    @property
    def conversion_type(self):
        return self._conversion_type

    @property
    def conversion_layers(self):
        return self._conversion_layers

    def _get_conversion_names(self, conversion_layers):
        """
        Set ``self._conversion_layers`` to the names of the layers to create.

        One name is needed per input layer, two for hybrid conversions. When no names are given they are
        generated as ``<conversion type value><1-based index>``.

        Raises:
            InputConversionLayersAdditionException: If the number of given names does not match.
        """
        num_of_conversion_layers = len(self._input_layers_to_translate)
        if self._conversion_type in HybridConversionType:
            num_of_conversion_layers *= 2

        if not conversion_layers:
            self._conversion_layers = [
                f"{self._conversion_type.value}{index + 1}" for index in range(num_of_conversion_layers)
            ]
        elif len(conversion_layers) == num_of_conversion_layers:
            self._conversion_layers = conversion_layers
        else:
            raise InputConversionLayersAdditionException(
                f"Given {len(conversion_layers)} names for the new layers when {num_of_conversion_layers} names are "
                f"required",
            )

    def add_conversion_layers(self):
        """
        Add the conversion layers in front of every input layer to translate.

        Returns:
            tuple: The updated ``(hailo_nn, params)``.

        Raises:
            InputConversionLayersAdditionException: If an input layer does not have 3 output features, or if a
                YUY2 to Hailo YUV conversion (direct or as part of a hybrid conversion) is requested for an input
                whose output width is not a multiple of ``YUV2_TO_HAILO_YUV_OUTPUT_WIDTH_FACTOR``.
        """
        self._logger.verbose(f"Adding {self._conversion_type} layers")

        is_hybrid = self._conversion_type in HybridConversionType
        # A hybrid conversion is a format conversion to Hailo YUV followed by a YUV -> RGB color conversion.
        # Resolve the format conversion up front so it is validated like a direct format conversion.
        if is_hybrid:
            conversion_type = self.HYBRID_CONVERSION_TO_FORMAT_CONVERSION[self._conversion_type]
        else:
            conversion_type = self._conversion_type

        for index, input_layer in enumerate(self._input_layers_to_translate):
            self._validate_input_layer(input_layer, conversion_type)
            if is_hybrid:
                color_conversion_layer = self._create_conversion_layer(
                    name=self._conversion_layers[2 * index + 1],
                    conversion_type=ColorConversionType.yuv_to_rgb,
                    in_emulation_graph=True,
                )
                self._hailo_nn.push_layer(color_conversion_layer, [input_layer])
                conversion_layer_name = self._conversion_layers[2 * index]
            else:
                conversion_layer_name = self._conversion_layers[index]

            conversion_layer = self._create_conversion_layer(
                conversion_layer_name,
                conversion_type,
                self._in_emulation_graph,
            )
            # The input layer's shapes are rewritten to the (pre-conversion) source layout.
            input_layer.input_shapes = self._get_input_layer_shapes(input_layer, conversion_layer)
            input_layer.output_shapes = input_layer.input_shapes
            self._hailo_nn.push_layer(conversion_layer, [input_layer])

        return self._hailo_nn, self._params

    def _validate_input_layer(self, input_layer, conversion_type):
        """
        Check that ``input_layer`` can be fed through ``conversion_type``.

        Raises:
            InputConversionLayersAdditionException: If the layer does not have 3 output features, or if
                ``conversion_type`` is YUY2 to Hailo YUV and the output width is not a multiple of
                ``YUV2_TO_HAILO_YUV_OUTPUT_WIDTH_FACTOR``.
        """
        if input_layer.output_features != 3:
            raise InputConversionLayersAdditionException(
                f"{input_layer.name} has {input_layer.output_features} output features instead of 3 (for "
                f"{self._conversion_type.value})",
            )
        if (
            conversion_type == FormatConversionType.yuy2_to_hailo_yuv
            and input_layer.output_width % self.YUV2_TO_HAILO_YUV_OUTPUT_WIDTH_FACTOR != 0
        ):
            raise InputConversionLayersAdditionException(
                f"The output width of {input_layer.name} layer must be a multiple of {self.YUV2_TO_HAILO_YUV_OUTPUT_WIDTH_FACTOR} when converting from YUY2 to Hailo YUV"
            )

    def _set_color_conversion_layer_params(self, conversion_layer, conversion_type):
        """
        Write the ``kernel:0`` ([1, 1, 3, 3]) and ``bias:0`` (3,) params of a color conversion layer.

        Fresh arrays are created on every call so layers never share parameter buffers.
        """
        color_params = self._COLOR_CONVERSION_PARAMS.get(conversion_type)
        if color_params is None:
            # NOTE(review): unlisted color conversion types leave the layer without kernel/bias params (as
            # before); confirm that no other ColorConversionType can reach this point.
            return
        kernel, bias = color_params
        self._params[f"{conversion_layer.name}/kernel:0"] = np.array([[kernel]])
        self._params[f"{conversion_layer.name}/bias:0"] = np.array(bias)

    def _create_conversion_layer(self, name, conversion_type, in_emulation_graph):
        """
        Create the layer implementing ``conversion_type``.

        Color conversions become a 1x1, stride-1, linear ``FusedConv2DLayer`` whose params are written to
        ``self._params``; any other conversion becomes a ``FormatConversionLayer``.
        """
        if conversion_type in ColorConversionType:
            conversion_layer = FusedConv2DLayer()
            conversion_layer.name = name
            conversion_layer.kernel_shape = [1, 1, 3, 3]
            conversion_layer.strides = [1, 1, 1, 1]
            conversion_layer.dilations = [1, 1, 1, 1]
            conversion_layer.padding = PaddingType.same
            conversion_layer.activation = ActivationType.linear
            self._set_color_conversion_layer_params(conversion_layer, conversion_type)
        else:
            conversion_layer = FormatConversionLayer()
            conversion_layer.name = name
            conversion_layer.conversion_type = conversion_type

        conversion_layer.in_emulation_graph = in_emulation_graph
        conversion_layer.add_original_name(
            f"{MODEL_SCRIPT_LAYER_PREFIX}_{conversion_layer.name_without_scope}",
        )

        return conversion_layer

    def _get_input_layer_shapes(self, input_layer, conversion_layer):
        """
        Return the new input shape of ``input_layer`` for the source format of ``conversion_layer``.

        Color conversions and TF RGB keep the shape, TF RGBX adds one channel, YUY2 uses 2 channels and NV
        converters halve axis 1 (presumably the height - TODO confirm the shape layout).

        Raises:
            InputConversionLayersAdditionException: For an unknown format conversion type.
        """
        output_shape = input_layer.input_shape.copy()
        # Short-circuit first: color conversion layers are FusedConv2D layers without ``conversion_type``.
        if (
            self._conversion_type in ColorConversionType
            or conversion_layer.conversion_type == FormatConversionType.tf_rgb_to_hailo_rgb
        ):
            pass
        elif conversion_layer.conversion_type == FormatConversionType.tf_rgbx_to_hailo_rgb:
            output_shape[-1] += 1
        elif conversion_layer.conversion_type == FormatConversionType.yuy2_to_hailo_yuv:
            output_shape[-1] = 2
        elif conversion_layer.is_nv_converter():
            output_shape[1] //= 2
        else:
            raise InputConversionLayersAdditionException(f"Unknown input conversion type {self._conversion_type}")

        return output_shape

    def _get_input_names_to_translate(self, input_names_to_translate):
        """
        Resolve the requested input names into layer objects.

        Returns all input layers of the network when no names are given.

        Raises:
            InputConversionLayersAdditionException: If any named layer is not an input layer of the network.
        """
        input_layers = self._hailo_nn.get_input_layers()
        if not input_names_to_translate:
            return input_layers

        layers_to_translate = [self._hailo_nn.get_layer_by_name(input_name) for input_name in input_names_to_translate]

        invalid_input_names = [layer.name for layer in layers_to_translate if layer not in input_layers]
        if invalid_input_names:
            raise InputConversionLayersAdditionException(
                f"{invalid_input_names} are not valid inputs names",
            )

        return layers_to_translate


def translate_dataset_to_rgb(dataset, color_type):
    """
    Apply a color conversion to a batch of 3-channel images.

    Args:
        dataset (numpy.ndarray): 4-D batch of images with 3 channels on axis 3 (e.g. ``(N, H, W, 3)``).
        color_type (ColorConversionType): One of ``yuv_to_rgb``, ``yuv601_to_rgb``, ``yuv709_to_rgb``,
            ``bgr_to_rgb`` or ``rgb_to_bgr``.

    Returns:
        numpy.ndarray: float64 array with the same shape as ``dataset`` holding the converted images.

    Raises:
        ValueError: If ``color_type`` is not one of the supported conversions.
    """
    assert dataset.shape[3] == 3
    # Matrices are 3x3 with rows indexing input channels and columns output channels, so that
    # ``np.dot`` over the last (channel) axis keeps the dataset shape.
    if color_type in [ColorConversionType.yuv_to_rgb, ColorConversionType.yuv601_to_rgb]:
        transition_matrix = np.array([[1.164, 1.164, 1.164], [0, -0.392, 2.017], [1.596, -0.813, 0]])
        bias = np.array([-222.912, 135.616, -276.8])
    elif color_type == ColorConversionType.yuv709_to_rgb:
        transition_matrix = np.array([[1.164, 1.164, 1.164], [0, -0.213, 2.112], [1.793, -0.533, 0]])
        bias = np.array([-248.128, 76.864, -288.96])
    elif color_type in [ColorConversionType.bgr_to_rgb, ColorConversionType.rgb_to_bgr]:
        # Plain channel swap. Must be 2-D: a (1, 1, 3, 3) matrix would make np.dot add extra axes.
        transition_matrix = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]])
        bias = None
    else:
        raise ValueError(f"Unsupported color conversion type {color_type}")

    # One vectorized product over the whole batch; cast so the result is always float64.
    rgb_dataset = np.dot(dataset, transition_matrix).astype(np.float64)
    if bias is not None:
        rgb_dataset += bias

    return rgb_dataset
