from typing import Dict, Optional

from pydantic.v1 import BaseModel, Field

from hailo_model_optimization.acceleras.utils.acceleras_definitions import FlowState


class BaseFlowState(BaseModel):
    """
    Common base of every flow-state record of a HailoModel.

    A flow state aggregates information such as whether each lossy element is enabled or disabled.
    This base only declares the state's name; subclasses add their own fields.
    """

    class Config:
        # Populate enum-typed fields with the enum member's value rather than the member itself.
        use_enum_values = True

    # Required: no default is supplied, so the name must be given on construction.
    full_name: str = Field(..., description="state name")


class LossyState(BaseFlowState):
    """
    Flow state of a single lossy element.

    Holds the element's class name, its lossless flag and a free-form dict of extra arguments.
    """

    # Optional: resolves to None when omitted.
    lossy_class_type: Optional[str] = Field(None, description="__class__.__name__ of lossy element")
    # Required flag.
    # NOTE(review): the description text reads as the opposite of the field name - confirm the intended wording.
    is_lossless: bool = Field(..., description="Turn the elementt lossy")
    # Untyped dict of extra arguments; empty when omitted.
    lossy_dict_kwgs: Dict = Field(
        {}, description="args that are relevent only to some types of LossyElements such as 'bits' "
    )


class AtomicOpState(BaseFlowState):
    """
    Flow state of a single atomic op.

    Combines the op's class name and flow status with the states of its lossy elements
    (grouped into input, output and weight mappings), three optional enable flags and a
    free-form dict of extra arguments.
    """

    # Optional: resolves to None when omitted.
    aops_class_type: Optional[str] = Field(None, description="__class__.__name__ of atomic op")
    # Required. The inherited Config sets use_enum_values, so the enum's value is what gets stored.
    status: FlowState = Field(
        ...,
        description="string enum, describe the flow of the AtomicOP. one of {FULLY_NATIVE, NUMERIC, BIT_EXACT}",
    )

    # Lossy-element states keyed by string; each mapping is empty when omitted.
    input_lossy_elements: Dict[str, LossyState] = Field({}, description="aggregates input LossyElements")
    output_lossy_elements: Dict[str, LossyState] = Field({}, description="aggregates output LossyElements")
    # Unlike the two mappings above, this one is also allowed to be None.
    weight_lossy_elements: Optional[Dict[str, LossyState]] = Field(
        {},
        description="aggregates weight LossyElements",
    )

    # Optional flags: each resolves to None when omitted.
    internal_encoding_enabled: Optional[bool] = Field(
        None, description="Describe the encoding status of the AtomicOp"
    )
    internal_decoding_enabled: Optional[bool] = Field(
        None, description="Describe the decoding status of the AtomicOp"
    )
    quant_inputs_enabled: Optional[bool] = Field(None, description="Describe the encoding status of the Op")

    # Untyped dict of extra arguments; empty when omitted.
    aops_dict_kwgs: Dict = Field(
        {},
        description="args that are relevent only to some types of aops such as aops that have lossless field ",
    )


class LayerState(BaseFlowState):
    """
    Flow state of a layer.

    Maps string keys to the flow states of the layer's members and carries an optional flag
    about enforcing internal encoding in the layer call.
    """

    # Empty when omitted.
    # NOTE(review): values are annotated with the base type although the description mentions
    # AtomicOps or Layers - presumably plain-dict input is validated as BaseFlowState only; confirm
    # that round-tripping through dicts keeps the subclass fields.
    atomic_ops: Dict[str, BaseFlowState] = Field(
        {},
        description="aggregates AtomicOps or Layers (in the case of decomposed layer)",
    )
    # Optional: resolves to None when omitted.
    enforce_internal_encoding_in_call: Optional[bool] = Field(
        None, description="Enforce internal encoding in layer call"
    )


class ModelState(BaseFlowState):
    """
    Flow state of a whole model: a mapping from string keys to LayerState entries.
    """

    # Empty when omitted.
    layers: Dict[str, LayerState] = Field(
        {},
        description="aggregates Layer or blocks of layers",
    )
