from typing import Dict, List, Literal, Tuple, Union

from pydantic.v1 import BaseModel, Field, validator

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    FormatConversionType,
    InputConversions,
    IOType,
    PostprocessTarget,
)
from hailo_model_optimization.acceleras.utils.modification_tracker.modification_tracker_utils import (
    AlgoModificationTracker,
)
from hailo_sdk_client.sdk_backend.script_parser.commands import SupportedCommands
from hailo_sdk_common.hailo_nn.hn_definitions import (
    ActivationType,
    ColorConversionType,
    HWLayerType,
    HybridConversionType,
    LayerType,
    NMSMetaArchitectures,
    ResizeBilinearPixelsMode,
    ResizeMethod,
)


class BaseCommandConfig(BaseModel):
    """Common base for the per-command configuration models in this module.

    Each subclass narrows ``cmd_type`` to a single ``SupportedCommands`` member
    via ``Literal[...]``.
    """

    # Command identifier; re-declared as a ``Literal`` in every subclass.
    cmd_type: str

    class Config:
        # Reject any field that is not declared on the model instead of ignoring it.
        extra = "forbid"


class NormalizationConfig(BaseCommandConfig):
    """Parameters of a ``SupportedCommands.NORMALIZATION`` command."""

    cmd_type: Literal[SupportedCommands.NORMALIZATION]
    # NOTE(review): presumably per-channel mean / std values — confirm against the producer.
    mean: List[float]
    std: List[float]
    # Name of the normalization layer.
    normalization_layer: str


class TransposeConfig(BaseCommandConfig):
    """Parameters of a ``SupportedCommands.TRANSPOSE`` command (nothing beyond ``cmd_type``)."""

    cmd_type: Literal[SupportedCommands.TRANSPOSE]


class ResizeConfig(BaseCommandConfig):
    """Parameters of a ``SupportedCommands.RESIZE`` command."""

    cmd_type: Literal[SupportedCommands.RESIZE]
    # NOTE(review): presumably the shapes before / after the resize — confirm dim ordering.
    input_shape: List[int]
    output_shape: List[int]
    pixels_mode: ResizeBilinearPixelsMode
    hw_layer_type: HWLayerType
    # Name of the resize layer.
    resize_layer_name: str
    interpolation_method: ResizeMethod


class InputConversionConfig(BaseCommandConfig):
    """Parameters of a ``SupportedCommands.INPUT_CONVERSION`` command.

    ``conversion_type`` may be a format, color or hybrid conversion; only format
    conversions are additionally checked against ``InputConversions``.
    """

    cmd_type: Literal[SupportedCommands.INPUT_CONVERSION]
    # Non-discriminated Union: pydantic.v1 tries the members left to right.
    conversion_type: Union[FormatConversionType, ColorConversionType, HybridConversionType]
    # Name of the conversion layer.
    conversion_layer_name: str
    emulate_conversion: bool

    @validator("conversion_type")
    def validate_conversion_type(cls, value):
        """Reject format conversions that are not listed in ``InputConversions``.

        Color and hybrid conversion types are returned unchanged.

        Raises:
            ValueError: if ``value`` is a ``FormatConversionType`` that is not
                in ``InputConversions``.
        """
        # Only format conversions are restricted; everything else passes through.
        if not isinstance(value, FormatConversionType):
            return value
        if value in InputConversions:
            return value
        raise ValueError(f"{value.value} is not a valid input conversion.")


class ChangeOutputActivationConfig(BaseCommandConfig):
    """Parameters of a ``SupportedCommands.CHANGE_OUTPUT_ACTIVATION`` command."""

    cmd_type: Literal[SupportedCommands.CHANGE_OUTPUT_ACTIVATION]
    # Activation before the change and the activation that replaces it.
    original_activation: ActivationType
    new_activation: ActivationType
    # NOTE(review): presumably the HN layer whose activation is replaced — confirm.
    hn_layer_name: str


class LogitsLayerConfig(BaseCommandConfig):
    """Parameters of a ``SupportedCommands.LOGITS_LAYER`` command."""

    cmd_type: Literal[SupportedCommands.LOGITS_LAYER]
    logit_layer_name: str
    # Only softmax and argmax are accepted as the logits layer type.
    logits_type: Literal[LayerType.softmax, LayerType.argmax]
    # NOTE(review): presumably the axis the softmax / argmax reduces over — confirm.
    axis: int
    engine: PostprocessTarget


class NMSConfig(BaseCommandConfig):
    """Parameters of a ``SupportedCommands.NMS_POSTPROCESS`` command."""

    cmd_type: Literal[SupportedCommands.NMS_POSTPROCESS]
    engine: PostprocessTarget
    meta_arch: NMSMetaArchitectures
    # NOTE(review): presumably the HN output-layer names feeding the NMS — confirm.
    hn_output_layers: List[str]
    # NOTE(review): presumably the layers that get a sigmoid before NMS — confirm.
    sigmoid_layers: List[str]


class SetKVCachePairsConfig(BaseCommandConfig):
    """Parameters of a ``SupportedCommands.SET_KV_CACHE_PAIR`` command."""

    cmd_type: Literal[SupportedCommands.SET_KV_CACHE_PAIR]
    # Bare ``Tuple`` places no constraint on length or element types.
    # NOTE(review): the name suggests a pair of layer names (e.g. Tuple[str, str]);
    # tightening it would reject inputs accepted today — confirm before changing.
    pair_names: Tuple
    cache_id: int
    cache_type: IOType


class SetKVCacheGlobalParamsConfig(BaseCommandConfig):
    """Parameters of a ``SupportedCommands.SET_KV_CACHE_GLOBAL_PARAMS`` command.

    NOTE(review): this model is not a member of ``command_types`` in this module,
    so it cannot be stored in ``ModificationsConfig.inputs`` / ``outputs`` —
    confirm that is intentional.
    """

    cmd_type: Literal[SupportedCommands.SET_KV_CACHE_GLOBAL_PARAMS]
    prefill_size: int
    cache_size: int


# Union of the command config models that may appear in the per-key lists of
# ``ModificationsConfig.inputs`` / ``outputs``. Each member pins ``cmd_type`` to a
# distinct ``SupportedCommands`` literal. The Union is not discriminated, so
# pydantic.v1 tries the members in the order listed.
# NOTE(review): ``SetKVCacheGlobalParamsConfig`` is defined in this module but is not
# a member here — confirm this is intentional (it holds global, not per-layer, params).
command_types = Union[
    NormalizationConfig,
    TransposeConfig,
    NMSConfig,
    ResizeConfig,
    InputConversionConfig,
    ChangeOutputActivationConfig,
    LogitsLayerConfig,
    SetKVCachePairsConfig,
]


class ModificationsConfig(BaseModel):
    """Command configs grouped under ``inputs`` and ``outputs``, plus a modification tracker.

    Both ``inputs`` and ``outputs`` map a string key to a list of ``command_types``
    entries; both default to empty.
    """

    # NOTE(review): keys are presumably input / output layer names — confirm with callers.
    inputs: Dict[str, List[command_types]] = Field({})
    outputs: Dict[str, List[command_types]] = Field({})
    # The default tracker instance is built once, when this class body runs at import time.
    # NOTE(review): relies on pydantic.v1 copying defaults per model instance so the shared
    # ``{}`` / tracker defaults are not mutated across instances. Switching to
    # ``default_factory`` would change ``.dict(exclude_defaults=True)`` output — confirm first.
    tracker: AlgoModificationTracker = Field(AlgoModificationTracker())

    class Config:
        # Reject undeclared fields, same as ``BaseCommandConfig``.
        extra = "forbid"
