"""Module with data classes for the SDK backend"""

from typing import Optional

from pydantic.v1 import BaseModel, Field
from tensorflow.data import Dataset

from hailo_model_optimization.acceleras.utils.acceleras_definitions import DistributionStrategy
from hailo_model_optimization.acceleras.utils.flow_state.updater import FlowCommands
from hailo_model_optimization.flows.optimization_flow import SupportedStops
from hailo_model_optimization.tools.orchestator import FlowCheckPoint
from hailo_sdk_client.exposed_definitions import InferenceContext
from hailo_sdk_common.export.hailo_graph_export import GraphExport


class InfoConfig(BaseModel):
    """Common pydantic base shared by the backend info containers in this module.

    Subclasses inherit a strict configuration: any field type is accepted,
    assignments are re-validated, and undeclared fields are rejected.
    """

    class Config:
        # Undeclared field names are an error rather than being silently ignored.
        extra = "forbid"
        # Re-run field validation whenever an attribute is assigned after construction.
        validate_assignment = True
        # Allow field types pydantic cannot build validators for (e.g. tf.data.Dataset);
        # such fields are checked with a plain isinstance test.
        arbitrary_types_allowed = True


class InternalContextInfo(InfoConfig):
    """State the SDK backend keeps about a single inference context.

    ``infer_context``, ``open`` and ``gpu_policy`` must be supplied. The
    remaining fields are optional.
    """

    # The InferenceContext this entry describes.
    infer_context: InferenceContext
    # NOTE(review): presumably whether the context is currently open -- confirm against callers.
    open: bool
    # Bare Optional[...] with no default: pydantic v1 makes this optional with a None default.
    graph_export: Optional[GraphExport]
    # GPU distribution policy, typed as DistributionStrategy.
    gpu_policy: DistributionStrategy
    # NOTE(review): pydantic v1 does not validate defaults (validate_all is off), so when this
    # is left unset the value is the plain str "" rather than a FlowCommands -- confirm intended.
    flow_commands: FlowCommands = Field(
        default="",
        description="Path to a file that can modify the model inference, after model Optimization",
    )
    # Optional LoRA adapter name; None when not provided.
    lora_adapter_name: Optional[str] = None


class InferInfo(InfoConfig):
    """Inputs for one inference run in the SDK backend.

    ``data_count`` is optional and defaults to None. ``run_eagerly``
    defaults to True.
    """

    # Context settings that this inference run uses.
    context_info: InternalContextInfo
    # Input data. Dataset is not a pydantic type, so it is accepted only because
    # the base config allows arbitrary types (instances are isinstance-checked).
    data: Dataset
    batch_size: int
    # NOTE(review): presumably the number of samples to take from ``data`` -- confirm with callers.
    data_count: Optional[int]
    # NOTE(review): presumably forwarded to TF/Keras eager execution -- confirm with callers.
    run_eagerly: bool = True


class CheckpointInfo(InfoConfig):
    """Controls how far a flow runs and which saved checkpoint it resumes from."""

    # Where the flow stops; defaults to SupportedStops.NONE.
    run_until: SupportedStops = Field(
        default=SupportedStops.NONE,
        description="until when to run the flow",
    )
    # Saved flow state to continue from; None when there is nothing to resume.
    flow_memento: Optional[FlowCheckPoint] = Field(
        default=None,
        description="Flow memento to continue from",
    )
    # Flag recording whether quantization has already been done.
    quantization_done: bool = False
