from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ColorConversionType,
    EmulationSupportedConversions,
    HybridConversionType,
    InputConversionTypes,
)
from hailo_sdk_client.sdk_backend.modification_config import InputConversionConfig
from hailo_sdk_client.sdk_backend.script_parser.commands import SupportedCommands
from hailo_sdk_client.sdk_backend.script_parser.model_modifications_commands import (
    ModelModificationsOnInputLayerCommand,
)
from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import AllocatorScriptParserException
from hailo_sdk_client.tools.input_conversion_layers_addition import InputConversionLayersAdder
from hailo_sdk_common.logger.logger import default_logger


class InputConversionCommand(ModelModificationsOnInputLayerCommand):
    """
    Add input conversion layers.

    Args:
        input_layer (str, optional): Input layer name to add input conversion layer after,
            default to all input layers.
        conversion_layers (list of str): List of names for the conversion layers that will be added.
            When a specific input layer is given, a hybrid conversion (a member of ``HybridConversionType``)
            requires exactly two names and any other conversion requires exactly one.
        conversion_type (:class:`~hailo_sdk_common.hailo_nn.hn_definitions.ColorConversionType` or
                         :class:`~hailo_sdk_common.hailo_nn.hn_definitions.FormatConversionType`):
            The conversion to apply. When parsed from a model script it is looked up by name in
            ``InputConversionTypes``.
        in_emulation_graph (bool): whether or not to include the conversion in the emulation graph.

    """

    def __init__(self, input_layer, conversion_layers, conversion_type, in_emulation_graph):
        super().__init__(SupportedCommands.INPUT_CONVERSION)
        self._input_layer = input_layer
        self._conversion_layers = conversion_layers
        self._conversion_type = conversion_type
        self._in_emulation_graph = in_emulation_graph

    def __str__(self):
        """Render the command back in model-script syntax: ``names = func([input_layer, ]type, emulator_support=X)``."""
        input_layer = f"{self._input_layer}, " if self._input_layer else ""
        conversion_layers = ", ".join(self._conversion_layers)
        return f"{conversion_layers} = {self.function_name.value}({input_layer}{self._conversion_type.value}, emulator_support={self._in_emulation_graph})"

    @property
    def conversion_type(self):
        """The conversion type this command applies."""
        return self._conversion_type

    @property
    def in_emulation_graph(self):
        """Whether the conversion is included in the emulation graph."""
        return self._in_emulation_graph

    @classmethod
    def from_tokens(cls, tokens):
        """
        Build the command from parsed model-script tokens.

        ``tokens.multiple_return_vals`` supplies the conversion layer names (the left-hand side of the
        assignment) and ``tokens.function_args`` the call arguments: an optional input layer name, the
        conversion type name, and an optional trailing dict of keyword arguments (e.g. ``emulator_support``).
        """
        conversion_layers = tokens.multiple_return_vals.asList()
        input_layer = None

        args = tokens.function_args.asList()
        # No input layer was given when the only positional argument is the conversion type
        # (optionally followed by the keyword-argument dict).
        if len(args) == 1 or (len(args) == 2 and isinstance(args[-1], dict)):
            conversion_type = InputConversionTypes[args[0]]
        else:
            input_layer = args[0]
            conversion_type = InputConversionTypes[args[1]]
        in_emulation_graph = cls.is_in_emulation_graph(conversion_type, tokens)

        return cls(input_layer, conversion_layers, conversion_type, in_emulation_graph)

    def get_layers(self):
        """Return the input layer this command targets, as a one-element list (``[None]`` when unset)."""
        return [self._input_layer]

    def validate_command(self, layers_scope_from_hn):
        """
        Validate the command against the layer names existing in the HN.

        Args:
            layers_scope_from_hn: Collection of existing layer names (used only for membership tests).

        Raises:
            AllocatorScriptParserException: If the given input layer does not exist, if the number of
                conversion layer names does not match the conversion type (checked only when an input
                layer is given), or if any conversion layer name already exists in the model.
        """
        if self._input_layer:
            if self._input_layer not in layers_scope_from_hn:
                raise AllocatorScriptParserException(f"Given layer {self._input_layer} not exist in the HN")
            # Hybrid conversions expect two conversion layer names; all others expect exactly one.
            is_hybrid = self._conversion_type in HybridConversionType
            expected_count = 2 if is_hybrid else 1
            if len(self._conversion_layers) != expected_count:
                expected_text = "two names are" if is_hybrid else "one name is"
                raise AllocatorScriptParserException(
                    f"Given {len(self._conversion_layers)} names for the new layer when {expected_text} required",
                )
        invalid_layer_names = [
            layer_name for layer_name in self._conversion_layers if layer_name in layers_scope_from_hn
        ]
        if invalid_layer_names:
            raise AllocatorScriptParserException(
                f"Given layer names {invalid_layer_names} exist in the model. Please use different names",
            )

    def add_scope(self, scope_names, force=False):
        """Prefix the input layer (if set) and every conversion layer name with a scope."""
        if self._input_layer:
            self._input_layer = self.add_scope_to_layer(scope_names, self._input_layer, force)
        self._conversion_layers = [
            self.add_scope_to_layer(scope_names, conversion_layer, force)
            for conversion_layer in self._conversion_layers
        ]

    def apply(self, hailo_nn, params, **kwargs):
        """
        Add the conversion layers to the model.

        Records one ``InputConversionConfig`` per translated input layer in ``self._meta_data`` (keyed
        by the input layer's name) and returns the result of
        ``InputConversionLayersAdder.add_conversion_layers()``.
        """
        # None means "translate all input layers".
        input_names_to_translate = [self._input_layer] if self._input_layer else None
        conversion_adder = InputConversionLayersAdder(
            hailo_nn=hailo_nn,
            params=params,
            input_names_to_translate=input_names_to_translate,
            conversion_type=self._conversion_type,
            conversion_layers=self._conversion_layers,
            in_emulation_graph=self._in_emulation_graph,
        )
        for idx, inp_layer in enumerate(conversion_adder.input_layers_to_translate):
            config = InputConversionConfig(
                cmd_type=SupportedCommands.INPUT_CONVERSION,
                conversion_type=conversion_adder.conversion_type,
                conversion_layer_name=conversion_adder.conversion_layers[idx],
                emulate_conversion=self._in_emulation_graph,
            )
            self._meta_data[inp_layer.name] = config

        return conversion_adder.add_conversion_layers()

    def _all_layers(self):
        """Return the conversion layer names followed by the input layer (when set)."""
        layers = self._conversion_layers.copy()
        if self._input_layer:
            layers.append(self._input_layer)
        return layers

    def _replace_all_layers(self, new_values):
        """Inverse of ``_all_layers``: assign names back in the same order (conversion layers, then input)."""
        count = len(self._conversion_layers)
        self._conversion_layers = new_values[:count]
        if self._input_layer:
            self._input_layer = new_values[count]

    def remove_scope(self):
        """Strip the scope from the conversion layer names and from the input layer (when set)."""
        self._conversion_layers = self._remove_scope(self._conversion_layers)
        # Falsy input layer means "all input layers", consistent with the rest of this class.
        if self._input_layer:
            self._input_layer = self._remove_scope(self._input_layer)

    @classmethod
    def is_in_emulation_graph(cls, conversion_type, tokens):
        """
        Decide whether the conversion is included in the emulation graph.

        The default is ``True`` for color conversions (members of ``ColorConversionType``) and ``False``
        otherwise. An explicit ``emulator_support`` keyword argument (the strings ``"True"`` / ``"False"``)
        overrides the default.

        Args:
            conversion_type: The parsed conversion type.
            tokens: Parsed script tokens; ``tokens.function_args[-1]`` may be a dict of keyword arguments.

        Returns:
            bool: Whether the conversion is emulated.

        Raises:
            AllocatorScriptParserException: If ``emulator_support`` is neither ``"True"`` nor ``"False"``,
                if emulation is requested for a conversion outside ``EmulationSupportedConversions``, or if
                emulation is disabled for a color conversion.
        """
        in_emulation_graph = conversion_type in ColorConversionType
        keyword_args = tokens.function_args[-1]
        if isinstance(keyword_args, dict) and "emulator_support" in keyword_args:
            # Read the flag by name: other keyword arguments may precede it in the dict.
            value = keyword_args["emulator_support"]
            if value == "True":
                in_emulation_graph = True
            elif value == "False":
                in_emulation_graph = False
            else:
                raise AllocatorScriptParserException(
                    f"Invalid value {value} given for emulator_support for {conversion_type.value} conversion - please "
                    f"use either True or False. "
                    "See usage guidelines at: Dataflow Compiler User Guide / Building Models / Model Optimization / Model Scripts",
                )

            emulation_not_supported = in_emulation_graph and conversion_type not in EmulationSupportedConversions
            # Color conversions cannot have emulation turned off.
            disable_emulation_not_supported = not in_emulation_graph and conversion_type in ColorConversionType
            if emulation_not_supported or disable_emulation_not_supported:
                raise AllocatorScriptParserException(
                    f"The flag emulator_support={value} is not supported by conversion type {conversion_type.value}. "
                    "See usage guidelines at: Dataflow Compiler User Guide / Building Models / Model Optimization / Model Scripts",
                )
        elif conversion_type not in HybridConversionType:
            default_logger().info(
                f"The flag emulator_support was not given for {conversion_type.value} conversion."
                f" using emulator_support={in_emulation_graph}. "
                "See usage guidelines at: Dataflow Compiler User Guide / Building Models / Model Optimization / Model Scripts",
            )

        if conversion_type in HybridConversionType and not in_emulation_graph:
            default_logger().info(
                f"{conversion_type.value} is a hybrid conversion. emulator_support=False is applied for "
                f"{InputConversionLayersAdder.HYBRID_CONVERSION_TO_FORMAT_CONVERSION[conversion_type].value}, "
                f"but {ColorConversionType.yuv_to_rgb.value} will be emulated.",
            )

        return in_emulation_graph
