from typing import List

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow


class LayerDecomposeFlow(LayerFlow):
    """
    Graph connectivity of a hailo_model whose nodes are ``BaseHailoLayer`` objects.

    A thin layer over ``LayerFlow`` that adds typed accessors for retrieving
    layers by name and for listing all layers in ``toposort_ops`` order.
    """

    def add_node(self, op: BaseHailoLayer):
        """Register ``op`` as a graph node.

        Delegates to ``LayerFlow.add_node`` unchanged; this override only adds
        the ``BaseHailoLayer`` annotation on ``op``.
        """
        super().add_node(op=op)

    def get_op(self, op_name) -> BaseHailoLayer:
        """Return the layer stored under ``op_name``.

        The lookup is a plain subscript on the mapping returned by
        ``_get_op_attribute``. It presumably raises ``KeyError`` for an unknown
        name; confirm against ``LayerFlow``.
        """
        return self._get_op_attribute()[op_name]

    def _get_node(self, node):
        """Normalize ``node`` to a node key.

        A ``BaseHailoLayer`` is replaced by its ``full_name``. Any other value
        is returned as-is.
        """
        return node.full_name if isinstance(node, BaseHailoLayer) else node

    def get_ops(self) -> List[BaseHailoLayer]:
        """Return every layer, ordered by the names yielded from ``toposort_ops``.

        NOTE(review): ``toposort_ops`` is defined in ``LayerFlow``. It is
        presumably a topological ordering of node names; confirm there.
        """
        # Fetch the name->layer mapping before computing the order, which
        # matches the original call sequence.
        layers_by_name = self._get_op_attribute()
        topo_order = self.toposort_ops()
        return [layers_by_name[name] for name in topo_order]
