from dataclasses import dataclass
from typing import Callable, List, Optional

from hailo_model_optimization.acceleras.encoding.encoding_data import EncodingNode


def default_state(inputs: List[EncodingNode], outputs: List[EncodingNode]) -> List[bool]:
    """Default state function: leave every output untouched.

    Args:
        inputs (List[EncodingNode]): The constraint inputs (ignored).
        outputs (List[EncodingNode]): The constraint outputs; only their count is used.

    Returns:
        List[bool]: One ``False`` per output, meaning no output was changed.
    """
    output_count = len(outputs)
    return [False for _ in range(output_count)]


def default_status(inputs: List[EncodingNode], outputs: List[EncodingNode]) -> bool:
    """Default status function: treat the constraint as always valid.

    Args:
        inputs (List[EncodingNode]): The constraint inputs (ignored).
        outputs (List[EncodingNode]): The constraint outputs (ignored).

    Returns:
        bool: Always ``True``.
    """
    del inputs, outputs  # not consulted: the default check never fails
    return True


@dataclass
class EncodingConstraint:
    """A constraint between encoding nodes.

    Bundles the function the constraint represents with two optional helpers:
    one that updates the states of the output nodes, and one that reports
    whether the constraint currently holds. An optional ``func_string``
    template controls how the constraint is rendered by :meth:`format`.
    """

    func: Callable
    """The main function this constraint represents."""

    state_function: Callable[[List[EncodingNode], List[EncodingNode]], List[bool]] = default_state
    """A function that, given a list of input `EncodingNode`s and a list of output `EncodingNode`s, updates
    the outputs' states.

    Args:
        inputs (List[EncodingNode]): A list of `EncodingNode` with the same length as this constraint's inputs.
        outputs (List[EncodingNode]): A list of `EncodingNode` with the same length as this constraint's outputs.

    Returns:
        List[bool]: A list with the same length as `outputs`. Each element represents whether the function changed
        the respective output `EncodingNode`.
    """

    status_function: Callable[[List[EncodingNode], List[EncodingNode]], bool] = default_status
    """A function that, given a list of input `EncodingNode`s and a list of output `EncodingNode`s, returns whether
    this constraint is valid.

    Args:
        inputs (List[EncodingNode]): A list of `EncodingNode` with the same length as this constraint's inputs.
        outputs (List[EncodingNode]): A list of `EncodingNode` with the same length as this constraint's outputs.

    Returns:
        bool: True if this constraint is valid.
    """

    func_string: Optional[str] = None
    """A `str.format` template representing this constraint function; positional fields receive the operands."""

    def format(self, *args, func_name="func"):
        """Render this constraint applied to ``args`` as a string.

        Args:
            *args: The operands to render. Any type is accepted; operands are
                converted to text in both rendering paths.
            func_name (str): Name used in the default ``name(arg0, arg1, ...)``
                rendering when ``func_string`` is not set. Defaults to ``"func"``.

        Returns:
            str: ``func_string.format(*args)`` when ``func_string`` is set,
            otherwise ``"<func_name>(<arg0>, <arg1>, ...)"``.
        """
        if self.func_string is None:
            # Convert each operand with str() so non-string operands are accepted here,
            # matching the func_string branch (str.format) instead of raising TypeError.
            return f"{func_name}({', '.join(map(str, args))})"
        return self.func_string.format(*args)
