from abc import ABC, abstractmethod
from logging import Logger
from typing import Optional, Union

import torch.nn as nn

from hailo_model_optimization.saitama.framework.common.saitama_module import SaitamaModule
from hailo_model_optimization.saitama.framework.model.model import SModel
from hailo_model_optimization.saitama.translators.translator_utils import TranslatorUtils


class BaseModelTranslator(ABC):
    """Abstract base for translators whose ``translate`` returns an ``SModel``.

    Subclasses must implement :meth:`translate`. The constructor only stores the
    (optional) logger and creates a ``TranslatorUtils`` helper exposed as
    ``self.utils``.
    """

    def __init__(self, logger: Optional[Logger] = None):
        """Store the logger and create the shared translator helper.

        Args:
            logger: Logger to use, or ``None`` (the default). No fallback logger
                is created here, so code using ``self.logger`` must tolerate
                ``None``.
        """
        self.logger = logger
        # Helper instance made available to subclasses as ``self.utils``.
        self.utils = TranslatorUtils()

    @abstractmethod
    def translate(self, *, dtype=None, device=None) -> SModel:
        """Produce the translated ``SModel``.

        Both arguments are keyword-only.

        Args:
            dtype: Optional target dtype. Presumably a ``torch.dtype`` — TODO
                confirm against concrete implementations.
            device: Optional target device. Presumably a ``torch.device`` or
                device string — TODO confirm against concrete implementations.

        Returns:
            The translated ``SModel``.
        """


class BaseLayerTranslator(BaseModelTranslator):
    """Abstract translator whose ``translate`` yields a module instead of an ``SModel``.

    The constructor (logger storage and ``self.utils``) is inherited unchanged
    from ``BaseModelTranslator``; subclasses must implement :meth:`translate`.
    """

    @abstractmethod
    def translate(self, *, dtype=None, device=None) -> Union[nn.Module, SaitamaModule]:
        """Produce the translated module.

        Both arguments are keyword-only.

        Args:
            dtype: Optional target dtype (expected type not shown here — TODO confirm).
            device: Optional target device (expected type not shown here — TODO confirm).

        Returns:
            Either a ``torch.nn.Module`` or a ``SaitamaModule``.
        """
        # NOTE(review): this return annotation differs from the base class's
        # ``SModel``; static type checkers may flag the override as incompatible
        # unless ``SModel`` relates to this Union — confirm.
