#!/usr/bin/env python

from hailo_sdk_client.model_translator.exceptions import UnexpectedNodeError
from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import BackendNotImplementedError


class ChainNode:
    """One step of an operator chain to be followed through a graph.

    A step matches a graph node by op type (prefix or exact) and, optionally,
    by a name fragment. Concrete subclasses (``BwdChainNode`` /
    ``FwdChainNode``) define ``direction``, which the chain-search helpers use
    to choose between ``graph.predecessors`` and ``graph.successors``.
    """

    class Direction:
        # String tags returned by ``direction`` in the concrete subclasses.
        BWD = "bwd"
        FWD = "fwd"

    def __init__(self, op, name=None):
        """
        Args:
            op: Op string to match. By default it is matched as a prefix of the
                node's ``op``; with ``exact_match=True`` it must be equal.
            name: Optional substring that must occur in one of the
                "/"-separated components of the node's name. ``None`` accepts
                any name.
        """
        self._op = op
        self._name = name

    def __str__(self):
        # ``direction`` is only implemented by subclasses; the base property
        # raises. ``__str__`` should never raise (it is used for logging and
        # error messages), so show ``None`` in that case.
        try:
            direction = self.direction
        except BackendNotImplementedError:
            direction = None
        return f"ChainNode (op={self._op!s}, name={self._name!s}, dir={direction!s})"

    @property
    def direction(self):
        raise BackendNotImplementedError("Function not implemented on SDK backend")

    def does_match(self, real_node, exact_match=False):
        """Return True if ``real_node`` satisfies this step.

        Args:
            real_node: Graph node exposing string attributes ``op`` and ``name``.
            exact_match: If True, ``real_node.op`` must equal this step's op;
                otherwise it only has to start with it.

        Returns:
            True when the op check passes and either no name filter was given
            or ``test_scope`` finds the name fragment in ``real_node.name``.
        """
        if exact_match:
            op_matches = real_node.op == self._op
        else:
            op_matches = real_node.op.startswith(self._op)
        if not op_matches:
            return False
        return self._name is None or test_scope(real_node.name, self._name)


class BwdChainNode(ChainNode):
    """Chain step whose next hop goes from the current node to its predecessors."""

    direction = property(
        lambda self: ChainNode.Direction.BWD,
        doc="Direction tag of this step: always ``ChainNode.Direction.BWD``.",
    )


class FwdChainNode(ChainNode):
    """Chain step whose next hop goes from the current node to its successors."""

    direction = property(
        lambda self: ChainNode.Direction.FWD,
        doc="Direction tag of this step: always ``ChainNode.Direction.FWD``.",
    )


def test_scope(vertex_name, needle):
    hay_stack = vertex_name.split("/")
    return any(needle in x for x in hay_stack)


def look_and_validate(graph, node, ops_chain, exception=None, exact_match=False):
    """Find the node reached by following ``ops_chain`` from ``node``, or raise.

    Args:
        graph: Graph passed through to ``look_for_node``.
        node: Node to start the search from.
        ops_chain: Sequence of ``ChainNode`` steps to follow.
        exception: Exception to raise when no path matches. Defaults to
            ``UnexpectedNodeError(node.name)``.
        exact_match: Forwarded to ``look_for_node``. If True, op types must
            match exactly instead of by prefix. Defaults to False, which keeps
            the previous behaviour.

    Returns:
        The node that matched the last step of ``ops_chain``.

    Raises:
        UnexpectedNodeError: If nothing matches and ``exception`` is None.
        exception: If nothing matches and ``exception`` was supplied.
    """
    searched_node = look_for_node(graph, node, ops_chain, exact_match=exact_match)
    if searched_node is not None:
        return searched_node
    if exception is None:
        raise UnexpectedNodeError(node.name)
    raise exception


def _look_for_node(graph, node, ops_chain, all_nodes, exact_match=False):
    """Depth-first search for a path starting at ``node`` that matches ``ops_chain``.

    Each ``ChainNode`` in ``ops_chain`` decides two things:
    - whether the next hop goes to ``graph.predecessors`` (``BWD``) or to
      ``graph.successors``;
    - which op/name the neighbour must have.

    Neighbours are tried in the order the graph yields them, and the first
    complete match wins.

    On success, the matched nodes are appended to ``all_nodes`` back-to-front.
    The node matching the last step comes first and the node matching the
    first step comes last. ``all_nodes`` is then returned. On failure,
    ``all_nodes`` is left unchanged and ``None`` is returned.

    ``ops_chain`` must be non-empty, because ``ops_chain[0]`` is read
    unconditionally.
    """
    step = ops_chain[0]
    if step.direction == ChainNode.Direction.BWD:
        neighbours = graph.predecessors(node)
    else:
        neighbours = graph.successors(node)

    for candidate in neighbours:
        if not step.does_match(candidate, exact_match):
            continue
        if len(ops_chain) == 1:
            # Final step matched: start recording the path.
            all_nodes.append(candidate)
            return all_nodes
        deeper = _look_for_node(graph, candidate, ops_chain[1:], all_nodes, exact_match=exact_match)
        if deeper is not None:
            # Append while unwinding, so only the successful path is recorded.
            all_nodes.append(candidate)
            return all_nodes
        # Dead end below this candidate; try the next neighbour.

    return None


def look_for_node(graph, node, ops_chain, exact_match=False):
    """Follow ``ops_chain`` from ``node`` and return the node matching its last step.

    Returns ``None`` when no path in ``graph`` matches the whole chain.
    """
    collected = []
    _look_for_node(graph, node, ops_chain, collected, exact_match)
    # ``_look_for_node`` records the path back-to-front, so index 0 holds the
    # node that matched the final step.
    return collected[0] if collected else None


def get_node_from_possible_chains(graph, node, possible_chains, exact_match=False):
    """Return the first non-``None`` ``look_for_node`` result over ``possible_chains``.

    Chains are tried in order, and the search stops at the first chain that
    matches. Returns ``None`` if no chain matches.
    """
    # Lazy generator: later chains are only searched if earlier ones fail.
    results = (look_for_node(graph, node, chain, exact_match) for chain in possible_chains)
    return next((found for found in results if found is not None), None)


def get_all_nodes_in_chain(graph, node, ops_chain, exact_match=False):
    """Return the nodes matched by each step of ``ops_chain``, in chain order.

    Element ``i`` of the result is the node that matched ``ops_chain[i]``.
    Returns ``None`` when no path matches the whole chain.
    """
    path = []
    _look_for_node(graph, node, ops_chain, path, exact_match=exact_match)
    if not path:
        return None
    # ``_look_for_node`` stores the path back-to-front; flip it so the first
    # step comes first.
    return list(reversed(path))


def get_all_nodes_from_possible_chains(graph, node, possible_chains, exact_match=False):
    """Return the full node path of the first chain in ``possible_chains`` that matches.

    Chains are tried in order, and the search stops at the first one with a
    non-empty match. Returns ``None`` if none match.
    """
    # Lazy generator: later chains are only searched if earlier ones fail.
    paths = (get_all_nodes_in_chain(graph, node, chain, exact_match) for chain in possible_chains)
    return next((nodes for nodes in paths if nodes), None)
