from typing import Optional

from hailo_model_optimization.acceleras.model.hailo_model.model_flow import ModelFlow
from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerType
from hailo_model_optimization.saitama.framework.model.model import SModel
from hailo_model_optimization.saitama.translators.base_translator import BaseLayerTranslator, BaseModelTranslator
from hailo_model_optimization.saitama.translators.hailo_translator.hn_layers_translators import (
    HN_TRANSLATION_REGISTRY,
    LayerMode,
)


class HailoTranslator(BaseModelTranslator):
    """Translate a Hailo HN model dictionary into a saitama ``SModel``.

    Each entry of ``model_hn["layers"]`` is dispatched, by its ``"type"`` field,
    to the layer translator registered for that ``LayerType`` in
    ``translator_registry``. The model connectivity is then rebuilt with
    ``ModelFlow.from_hn_layers``.
    """

    translator_registry = HN_TRANSLATION_REGISTRY

    def translate(
        self,
        model_hn: dict,
        acceleras_params,
        *,
        layer_mode: Optional[LayerMode] = None,
        dtype=None,
        device=None,
    ) -> SModel:
        """Translate ``model_hn`` into an ``SModel``.

        Args:
            model_hn: HN model dictionary. Must contain a ``"layers"`` mapping of
                layer name -> HN layer dict (each with a ``"type"`` key). May
                contain ``"net_params"`` with an ``"output_layers_order"`` entry.
            acceleras_params: Mapping of layer name -> per-layer params (queried
                with ``.get``). Layers missing from it receive an empty dict.
            layer_mode: Forwarded unchanged to every layer translator.
            dtype: Forwarded unchanged to every layer translator.
            device: Forwarded unchanged to every layer translator.

        Returns:
            The translated ``SModel`` built from all layers and the model flow.

        Raises:
            ValueError: If a layer's ``"type"`` is not a valid ``LayerType``
                (raised by the ``LayerType(...)`` lookup).
            NotImplementedError: If any layer type has no registered translator.
                This is checked before any layer is translated.
        """
        self.logger.info("Translating model")
        hn_layers = model_hn["layers"]

        # Resolve every layer's type up front so that support can be validated
        # before spending time translating layers that would be discarded.
        layer_types = {layer_name: LayerType(hn_layer["type"]) for layer_name, hn_layer in hn_layers.items()}
        self._check_supported(layer_types)

        layers = {}
        for layer_name, hn_layer in hn_layers.items():
            translator_cls = self.translator_registry[layer_types[layer_name]]
            translator: BaseLayerTranslator = translator_cls(self.logger)
            layers[layer_name] = translator.translate(
                layer_name,
                hn_layer,
                acceleras_params.get(layer_name, {}),
                layer_mode=layer_mode,
                dtype=dtype,
                device=device,
            )

        model_flow = self._build_model_flow(model_hn)
        return SModel(layers, model_flow)

    def _check_supported(self, layer_types: dict) -> None:
        """Raise ``NotImplementedError`` listing every layer type without a registered translator.

        Args:
            layer_types: Mapping of layer name -> ``LayerType``.
        """
        unsupported_layer_names = sorted(
            layer_name
            for layer_name, layer_type in layer_types.items()
            if layer_type not in self.translator_registry
        )
        if not unsupported_layer_names:
            return
        # Sort so the message is stable across runs (a raw set has arbitrary order).
        unsupported_types = sorted({layer_types[name] for name in unsupported_layer_names}, key=str)
        message = f"Unsupported layers: {unsupported_types} (layer names: {unsupported_layer_names})"
        self.logger.error(message)
        raise NotImplementedError(message)

    def _build_model_flow(self, model_hn: dict) -> ModelFlow:
        """Build the ``ModelFlow`` from the HN layers and optional output-layer order.

        ``net_params`` may be absent or ``None``; in both cases no explicit output
        order is passed to ``ModelFlow.from_hn_layers``.
        """
        net_params = model_hn.get("net_params") or {}
        out_layers_order = net_params.get("output_layers_order")
        return ModelFlow.from_hn_layers(model_hn["layers"], out_layers_order)
