from abc import ABC, abstractmethod
from collections import defaultdict, deque
from enum import Enum
from operator import attrgetter

import networkx as nx

from hailo_sdk_client.model_translator.exceptions import (
    MisspellNodeError,
    ParsingWithRecommendationException,
    UnsupportedModelError,
)
from hailo_sdk_client.model_translator.nn_graph import VertexParsingStatus
from hailo_sdk_client.model_translator.translator import HailoNNConverter
from hailo_sdk_common.hailo_nn.hn_layers import InputLayer
from hailo_sdk_common.tools.models_translator_helper import valid_orig_name

# op name carried by graph input vertices
INPUT_OP: str = "input"
# ops that make poor end nodes (e.g. a feature splitter, whose output edges are ambiguous)
INFEASIBLE_END_NODES: list = ["Split"]


class VertexState(Enum):
    """
    Traversal state of a graph vertex during layer creation.

    SKIPPED (0): the converter decided not to create a layer for the vertex.
    CONSUMED (1): not assigned in this module's visible code; accepted like PROCESSED by the final validation, while
        the traversal may still enqueue such a vertex to reach its successors.
    PROCESSED (2): a layer was created for the vertex (input vertices included).
    """

    SKIPPED, CONSUMED, PROCESSED = range(3)


class EdgeNNConverter(HailoNNConverter, ABC):
    def __init__(self, graph, start_node_names=None, end_node_names=None):
        """
        Set up the converter and resolve the vertices parsing starts from.

        Args:
            graph: the graph to translate (forwarded to ``HailoNNConverter``).
            start_node_names: optional names of the nodes parsing should start from.
            end_node_names: optional names of the nodes parsing should end at.
        """
        super().__init__(graph, start_node_names, end_node_names)
        self._input_vertices = self._get_input_vertices()
        # the resolved inputs are the default start recommendation; an unchanged recommendation is cleared later
        self._recommended_start_names = set(input_vertex.name for input_vertex in self._input_vertices)
        self._recommended_end_names = set()

    def _create_layers(self):
        """
        Run the layer-creation pipeline in order: input layers, the vertex-info hook, the traversal of the rest of
        the graph, and the final validation of the vertex states.
        """
        # per-vertex VertexState, shared by every stage below
        self._visited_states = {}
        pipeline = (
            self._add_input_layers,
            self._update_vertices_info,
            self._add_direct_layers,
            self._validate_processed_vertices,
        )
        for stage in pipeline:
            stage()

    def _update_vertices_info(self):
        """
        Hook run after the input layers are created and before the rest of the graph is traversed.

        The base implementation does nothing; subclasses may override it.
        """

    def _get_input_vertices(self):
        """
        Resolve the vertices parsing starts from.

        Without requested start node names, the graph's ``net_input`` is returned as-is. Otherwise every requested
        name is compared (after ``valid_orig_name`` normalization) with the network inputs; a name that is not a
        network input is looked up in the graph, and the vertices returned by its ``get_start_node_preds()`` are
        used instead.

        Returns:
            The resolved vertices without duplicates, in first-seen order.

        Raises:
            MisspellNodeError: if some of the requested start node names cannot be found in the graph.
        """
        net_input = self._graph.net_input
        if not self._start_node_names:
            return net_input

        # normalized name -> network input; the first input wins on collisions, exactly like a linear scan
        inputs_by_valid_name = {}
        for node in net_input:
            inputs_by_valid_name.setdefault(valid_orig_name(node.name), node)

        result = []
        wrong_names = []
        for start_node_name in self._start_node_names:
            valid_start_node_name = valid_orig_name(start_node_name)
            input_node = inputs_by_valid_name.get(valid_start_node_name)
            if input_node is not None:
                result.append(input_node)
                continue
            start_node = self._graph.get_vertex_by_valid_name(valid_start_node_name)
            if not start_node:
                wrong_names.append(start_node_name)
                continue
            result.extend(start_node.get_start_node_preds())
        if wrong_names:
            err_str = "start node names" if len(wrong_names) > 1 else "start node name"
            raise MisspellNodeError(f"Unable to find {err_str}: {wrong_names}, please verify and try again.")
        # de-duplicate while keeping a deterministic first-seen order: the order of this list drives the order of
        # the input layers and of the traversal, so it must not depend on object hashing as a set's order would
        return list(dict.fromkeys(result))

    def _add_input_layers(self):
        """
        Create an ``InputLayer`` for every resolved input vertex.

        Each input vertex gets a layer registered without an incoming edge, is mapped to that layer, is marked
        ``VertexState.PROCESSED`` and is registered through ``add_vertex_by_name``. When no end node names were
        requested, all of the graph's network inputs are flagged as part of the valid subgraph.

        Raises:
            UnsupportedModelError: if an input tensor is not 2, 3 or 4 dimensional.
        """
        supported_ranks = (2, 3, 4)
        for input_vertex in self._input_vertices:
            shapes = input_vertex.get_input_layer_shapes()
            for tensor_shape in shapes:
                tensor_rank = len(tensor_shape)
                if tensor_rank in supported_ranks:
                    continue
                raise UnsupportedModelError(
                    f"Input layer {input_vertex.name} has an input tensor with {tensor_rank} dimensions, which is not supported "
                    "by the Dataflow Compiler. Only 2-4 dimensional tensors are allowed",
                )
            input_layer = InputLayer.create(input_vertex.name, shapes)
            self._add_layer(input_layer, has_edge=False)
            self._vertices_to_layers[input_vertex] = input_layer
            self._visited_states[input_vertex] = VertexState.PROCESSED
            self._graph.add_vertex_by_name(input_vertex)
        if self._end_node_names is not None:
            return
        for net_input_vertex in self._graph.net_input:
            net_input_vertex.in_valid_subgraph = True

    def _validate_processed_vertices(self):
        """
        Make sure no skipped vertex still feeds other vertices.

        Skipped vertices without successors (graph outputs) are tolerated.

        Raises:
            UnsupportedModelError: if any vertex marked ``VertexState.SKIPPED`` has at least one successor.
        """
        skipped_non_outputs = []
        for vertex, state in self._visited_states.items():
            if state != VertexState.SKIPPED:
                continue
            has_successor = any(True for _ in self._graph.successors(vertex))
            if has_successor:
                skipped_non_outputs.append(vertex.name)
        if not skipped_non_outputs:
            return
        raise UnsupportedModelError(
            f"Failed to process entire graph. Might be an unsupported structure due to non-output skipped vertices. "
            f"Here's a full list of skipped vertices: {skipped_non_outputs}",
        )

    @abstractmethod
    def _should_skip_vertex(self, vertex):
        """
        Return whether ``vertex`` should be skipped instead of being converted.

        ``_add_direct_layers`` marks a not-yet-visited vertex for which this is truthy as ``VertexState.SKIPPED``;
        skipped vertices that still have successors make ``_validate_processed_vertices`` fail.
        """

    @abstractmethod
    def _layer_callback_from_vertex(self, vertex):
        """
        Handle a vertex that should be converted (presumably by creating its layer(s) — implemented by subclasses).

        ``_add_direct_layers`` calls this once for each reachable, not-skipped vertex that was not visited yet, and
        marks the vertex ``VertexState.PROCESSED`` right after the call returns.
        """

    def _add_direct_layers(self):
        """
        Convert the reachable part of the graph, then report recommendations when the parsing was partial.

        The traversal is breadth-first from the input vertices; vertices outside the valid subgraph are not
        converted. A not-yet-visited vertex is either marked ``SKIPPED`` (when ``_should_skip_vertex`` says so) or
        handed to ``_layer_callback_from_vertex`` and marked ``PROCESSED``. Successors are visited in name order and
        each is enqueued at most once, provided it is not visited yet (or is ``CONSUMED``) and is not a shape op.

        Afterwards every visited vertex is marked ``PARSED``. Visited vertices that are neither skipped, consts nor
        error vertices become successful end nodes when they have no successors, only shape-op successors, or are
        one of the requested end nodes.

        Raises:
            ParsingWithRecommendationException: when start/end recommendations remain after dropping the ones equal to
                the original inputs / end nodes (e.g. when errors were recorded, or the successful end nodes differ
                from the requested ones).
        """
        vertex_queue = deque(self._input_vertices.copy())
        vertex_set = set(vertex_queue)
        while vertex_queue:
            vertex = vertex_queue.popleft()
            self._current_vertex = vertex
            if not vertex.in_valid_subgraph:
                continue

            if self._should_skip_vertex(vertex) and vertex not in self._visited_states:
                self._visited_states[vertex] = VertexState.SKIPPED
                self._logger.debug(f"Skipping vertex {vertex.name}")
            elif vertex not in self._visited_states:
                self._logger.debug(f"Processing vertex {vertex.name}")
                self._layer_callback_from_vertex(vertex)
                self._visited_states[vertex] = VertexState.PROCESSED

            # CONSUMED vertices are still enqueued, so traversal continues through them to their successors
            for node in sorted(self._graph.successors(vertex), key=attrgetter("name")):
                if (
                    (node not in self._visited_states or self._visited_states[node] == VertexState.CONSUMED)
                    and node not in vertex_set
                    and not node.is_shape_op()
                ):
                    vertex_queue.append(node)
                    vertex_set.add(node)

        for vertex, vertex_state in self._visited_states.items():
            if vertex_state != VertexState.SKIPPED and not vertex.is_const() and vertex not in self._errors_dict:
                successors = list(self._graph.successors(vertex))
                # no successors / all successors are shape (branch is truncated) - successful end node
                # (all() over an empty list is True, which covers the "no successors" case)
                if all(succ.is_shape_op() for succ in successors) or vertex.name in self.end_node_names:
                    self._successful_end_nodes.add(vertex.name)
            vertex.parsing_status = VertexParsingStatus.PARSED

        suggestions_msg = ""
        if self._errors_dict:
            errors_str = self._generate_suggestions()
            suggestions_msg = f"Parsing failed. The errors found in the graph are:{errors_str}"
        elif set(self._successful_end_nodes) != set(self.end_node_names) and len(self._successful_end_nodes) >= len(
            self.end_node_names
        ):
            # no errors occurred, but there are more or different possible end nodes;
            # copy, so that in-place changes to the recommendations never alter the successful end nodes
            self._recommended_end_names = set(self._successful_end_nodes)
            suggestions_msg = "Mismatch between original end nodes and successful end nodes"
            if len(self._successful_end_nodes) > len(self.end_node_names):
                suggestions_msg = "More possible end nodes found, consider using them for more precise parsing"
        self._remove_original_nodes_suggestions()
        if self._recommended_start_names or self._recommended_end_names:
            raise ParsingWithRecommendationException(
                suggestions_msg,
                recommended_start_node_names=list(self._recommended_start_names),
                recommended_end_node_names=list(self._recommended_end_names),
                parsing_report=self.get_parsing_report(from_error=True),
            )

    def _remove_original_nodes_suggestions(self):
        """
        Drop recommendations that would not change anything.

        * When the recommended start names are exactly the names of the input vertices, they are cleared.
        * When the recommended end names are exactly the requested end node names, they are cleared.
        """
        input_vertex_names = {input_vertex.name for input_vertex in self._input_vertices}
        starts_unchanged = self._recommended_start_names == input_vertex_names
        if starts_unchanged:
            self._recommended_start_names.clear()
        ends_unchanged = self._recommended_end_names == set(self.end_node_names)
        if ends_unchanged:
            self._recommended_end_names.clear()

    def _generate_suggestions(self):
        """
        Build start/end node recommendations from the recorded errors.

        Runs, in order: marking of error descendants, extraction of error-based start/end nodes, the refinement
        steps below, path mapping, and the final selection of subgraphs.

        Returns:
            The error description text produced by ``_add_start_end_nodes_based_on_errors``.
        """
        self._set_error_nodes_infeasible_ends()
        errors_description = self._add_start_end_nodes_based_on_errors()
        refinement_steps = (
            # drop redundant start/end recommendations
            self._remove_redundant_recommendations,
            # include alternative routes that stem from bottlenecks parts of the graph
            self._find_bottlenecks_with_alternative_end,
            # exclude "bad" end nodes such as feature splitter (has ambiguous output edges)
            self._find_infeasible_alternative_end,
            # refrain from suggesting start/end nodes that are inefficient for hardware inference
            self._prevent_transpose_hw_suggestion,
        )
        for step in refinement_steps:
            step()
        # map start to end nodes by finding a possible path between them
        paths_by_end_node = self._map_possible_paths()
        path_unions = self._find_connected_components_union(paths_by_end_node)
        independent_paths, valid_unions = self._extract_valid_subgraphs(path_unions, paths_by_end_node)
        self._extract_independent_optimal_subgraphs(independent_paths, valid_unions)
        return errors_description

    def _set_error_nodes_infeasible_ends(self):
        """
        Mark descendants of error vertices as impossible end nodes and drop them from the successful end nodes.

        Error vertices already marked as descendants of an error are skipped, as are error vertices fed directly by
        an input vertex (so a proper start node can still be recommended). For every other error vertex, each of its
        descendants is examined:

        * a predecessor of the descendant that is not itself a descendant is added to the recommended end nodes
          (it may belong to a different branch);
        * a descendant with more than one predecessor gets ``is_descendant_of_error_and_optional_end_node = False``
          if any of its predecessors is an error vertex or is already marked that way.

        Finally, successful end nodes whose vertex has ``is_descendant_of_error_and_optional_end_node`` False are
        removed from ``self._successful_end_nodes``.
        """
        for err_vertex in self._errors_dict:
            # if vertex is already descendant of error, no need to update its descendants
            # also skip an error vertex if it comes straight after an input vertex, in order to have a proper start
            # node recommendation (exact op match: INPUT_OP is a string, so `in` would be a substring test)
            if not err_vertex.is_descendant_of_error_and_optional_end_node or any(
                pred.op == INPUT_OP for pred in self.graph.predecessors(err_vertex)
            ):
                continue
            descendants = nx.descendants(self.graph, err_vertex)
            for succ in descendants:
                is_descendant_of_error_and_optional_end_node = True
                preds = list(self.graph.predecessors(succ))
                for pred in preds:
                    if pred not in descendants:
                        # special case where a vertex is a descendant of an error, but has a predecessor that is a
                        # candidate for end node as its a part of a different branch (e.g. sparseint-Reshape_ 401)
                        self._recommended_end_names.add(pred.name)
                    if not pred.is_descendant_of_error_and_optional_end_node or pred in self._errors_dict:
                        is_descendant_of_error_and_optional_end_node = False
                if not is_descendant_of_error_and_optional_end_node and len(preds) > 1:
                    succ.is_descendant_of_error_and_optional_end_node = False
        # remove descendants of errors (vertices no longer flagged as optional end nodes) from successful end nodes
        descendant_end_nodes = [
            node
            for node in self._successful_end_nodes
            if not self.graph.get_vertex_by_name(node).is_descendant_of_error_and_optional_end_node
        ]
        self._successful_end_nodes.difference_update(descendant_end_nodes)

    def _find_alternative_start_nodes(self, current_start_vertex):
        """
        Breadth-first search downstream of ``current_start_vertex`` for usable start nodes.

        A successor that is not an error vertex and is an optional graph start node is collected, and the search does
        not continue past it; every other successor is explored further. ``current_start_vertex`` itself is never
        collected.

        Returns:
            list of the collected vertex names (may contain duplicates when reached through several paths).
        """
        pending = deque((current_start_vertex,))
        seen = set()
        candidates = []
        while pending:
            vertex = pending.popleft()
            if vertex in seen:
                continue
            seen.add(vertex)
            for successor in self.graph.successors(vertex):
                usable = successor not in self._errors_dict and successor.is_optional_graph_start_node()
                if usable:
                    candidates.append(successor.name)
                    continue
                pending.append(successor)
        return candidates

    def _add_start_end_nodes_based_on_errors(self):
        """
        Extract recommended start and end nodes based on the errors found in the graph.

        For every failed vertex (``self._errors_dict`` maps it to ``(error, successor names, predecessor names)``):

        * its parsing status is set to ``FAILED`` and a line describing the error is appended to the returned text;
        * each successor that is not an error vertex and is an optional graph start node is recommended as a start
          node; otherwise the search of ``_find_alternative_start_nodes`` from that successor supplies the start nodes;
        * the other successors ("siblings") of each predecessor, excluding error vertices and consts, are recommended
          as start nodes;
        * each predecessor that is not an error vertex and is still flagged as an optional end node is recommended as
          an end node.

        Finally, all successful end nodes are added to the recommended end nodes.

        Returns:
            str: the error descriptions, one per failed vertex, each starting on a new line.
        """
        error_lines = []
        for vertex, (error, succs, preds) in self._errors_dict.items():
            self._logger.debug(f"Failed vertex {vertex.name}")
            vertex.parsing_status = VertexParsingStatus.FAILED
            error_lines.append(f"\n {type(error).__name__} in op {vertex.name}: {error!s}")
            for succ in succs:
                succ_vertex = self.graph.get_vertex_by_name(succ)
                if succ_vertex not in self._errors_dict and succ_vertex.is_optional_graph_start_node():
                    self._recommended_start_names.add(succ)
                else:
                    self._recommended_start_names.update(self._find_alternative_start_nodes(succ_vertex))
            for pred in preds:
                pred_vertex = self.graph.get_vertex_by_name(pred)
                # in case a pred has more than one succ, add them to the recommended start nodes
                siblings = [
                    sibling.name
                    for sibling in set(self.graph.successors(pred_vertex)) - {vertex}
                    if sibling not in self._errors_dict and not sibling.is_const()
                ]
                self._recommended_start_names.update(siblings)
                # _errors_dict is keyed by vertices, so resolve the pred name before the membership test
                if pred_vertex not in self._errors_dict and pred_vertex.is_descendant_of_error_and_optional_end_node:
                    self._recommended_end_names.add(pred)
        self._recommended_end_names.update(self._successful_end_nodes)
        return "".join(error_lines)

    @abstractmethod
    def _prevent_transpose_hw_suggestion(self, start_names=None, end_names=None):
        """
        Prevent inclusion of operators that may harm core efficiency while being easily offloaded to the host,
        for example: spatial reshape/transpose operators.

        ``_generate_suggestions`` calls this hook without arguments, so implementations must be callable with no
        arguments; ``start_names`` / ``end_names`` are optional and kept only for backward compatibility.
        Implementations presumably adjust ``self._recommended_start_names`` / ``self._recommended_end_names`` in
        place — confirm against the subclasses.
        """

    def _remove_redundant_recommendations(self):
        """
        Prune the start and end node recommendations in place.

        Start nodes: a recommended start node is dropped when one of its predecessors is also a recommended start
        node that is not an error vertex.

        End nodes: a recommended end node is kept only if none of its predecessors is an error vertex, it is not an
        error vertex itself, it is still flagged as an optional end node
        (``is_descendant_of_error_and_optional_end_node``), and it is not a shape op.
        """
        # remove start nodes whose predecessors are also (non-error) start recommendations
        finalized_start_names = set()
        for start_node in self._recommended_start_names:
            cur_start_vertex = self.graph.get_vertex_by_name(start_node)
            start_node_preds = list(self.graph.predecessors(cur_start_vertex))
            if not any(
                node_pred.name in self._recommended_start_names and node_pred not in self._errors_dict
                for node_pred in start_node_preds
            ):
                finalized_start_names.add(start_node)
        # update in place so other references to the set stay valid
        self._recommended_start_names.clear()
        self._recommended_start_names.update(finalized_start_names)
        # end nodes
        recommended_end_names_screened = set()
        for cur_end_node in self._recommended_end_names:
            cur_end_vertex = self.graph.get_vertex_by_name(cur_end_node)
            end_node_preds = self.graph.predecessors(cur_end_vertex)
            if (
                not any(node_pred in self._errors_dict for node_pred in end_node_preds)
                and cur_end_vertex not in self._errors_dict
                and cur_end_vertex.is_descendant_of_error_and_optional_end_node
                and not cur_end_vertex.is_shape_op()
            ):
                # add node to recommended end nodes if it is not a descendant of an error
                recommended_end_names_screened.add(cur_end_node)
                # remove pred end name from recommended end names if its successor is now recommended
                # NOTE(review): if `predecessors` returns a one-shot iterator (networkx does), `end_node_preds` was
                # already exhausted by the any(...) above; it also yields vertices while this set holds names. This
                # line therefore likely removes nothing — confirm the intended behavior before relying on it.
                recommended_end_names_screened.difference_update(end_node_preds)
        self._recommended_end_names.clear()
        self._recommended_end_names.update(recommended_end_names_screened)

    def _find_infeasible_alternative_end(self):
        """
        Remove infeasible end nodes from the recommended end nodes.

        A recommended end node whose op is in ``INFEASIBLE_END_NODES`` or equals ``INPUT_OP`` is removed. For an
        ``INFEASIBLE_END_NODES`` op, the recommended end nodes that descend from it are removed as well, and its
        predecessors are recommended instead. A candidate that descends from an already handled infeasible end node
        is left alone, since it was removed together with that node's descendants.

        Candidates are handled ancestors first (fewest ancestors first, then by name), so the outcome does not depend
        on the iteration order of the recommendation set when infeasible end nodes are nested.
        """
        removable_ops = [*INFEASIBLE_END_NODES, INPUT_OP]
        candidates = {}
        for vertex_name in self._recommended_end_names:
            vertex = self.graph.get_vertex_by_name(vertex_name)
            if vertex.op in removable_ops:
                candidates[vertex_name] = vertex
        # in an acyclic graph an ancestor has strictly fewer ancestors than any of its descendants, so this order
        # handles ancestors before their descendants; the name breaks ties deterministically
        ordered_names = sorted(candidates, key=lambda name: (len(nx.ancestors(self.graph, candidates[name])), name))

        handled_infeasible = []  # infeasible end node vertices already replaced by their predecessors
        finalized_end_names = self._recommended_end_names.copy()
        for vertex_name in ordered_names:
            vertex = candidates[vertex_name]
            if any(nx.has_path(self.graph, handled, vertex) for handled in handled_infeasible):
                # already removed together with the descendants of a handled infeasible end node
                continue
            finalized_end_names.remove(vertex_name)
            if vertex.op not in INFEASIBLE_END_NODES:
                continue
            # remove descendants of infeasible end nodes from the recommended end nodes
            descendants = nx.descendants(self.graph, vertex)
            finalized_end_names.difference_update(
                {
                    end_node
                    for end_node in self._recommended_end_names
                    if self.graph.get_vertex_by_name(end_node) in descendants
                }
            )
            # recommend the infeasible end node's predecessors instead
            handled_infeasible.append(vertex)
            finalized_end_names.update(pred.name for pred in self.graph.predecessors(vertex))

        self._recommended_end_names = finalized_end_names

    def _find_bottlenecks_with_alternative_end(self):
        """
        Recommend extra end nodes right before the parts of the graph that descend from errors.

        For every vertex that is a descendant of an error (``is_descendant_of_error_and_optional_end_node`` is False)
        without being an error vertex itself, each of its predecessors is recommended as an end node when it:

        * is still flagged as an optional end node and is not an error vertex,
        * is neither a const nor a shape op,
        * is in the valid subgraph, and
        * is reachable from one of the recommended start nodes (or, when there are none, from any non-const vertex
          without incoming edges).
        """
        candidate_preds = []
        for vertex in self.graph.vertices_by_name.values():
            if vertex.is_descendant_of_error_and_optional_end_node or vertex in self._errors_dict:
                continue
            for pred in self.graph.predecessors(vertex):
                if (
                    pred.is_descendant_of_error_and_optional_end_node
                    and pred not in self._errors_dict
                    and not pred.is_const()
                    and not pred.is_shape_op()
                    and pred.in_valid_subgraph
                ):
                    candidate_preds.append(pred)
        if not candidate_preds:
            return

        if self._recommended_start_names:
            start_vertices = [self.graph.get_vertex_by_name(name) for name in self._recommended_start_names]
        else:
            start_vertices = [
                node for node, degree in dict(self.graph.in_degree).items() if degree == 0 and not node.is_const()
            ]
        # everything reachable from a start vertex (the start vertices included), computed once instead of running a
        # path search for every candidate/start pair
        reachable_from_start = set(start_vertices)
        for start_vertex in start_vertices:
            reachable_from_start.update(nx.descendants(self.graph, start_vertex))

        self._recommended_end_names.update(pred.name for pred in candidate_preds if pred in reachable_from_start)

    def _find_intersecting_nodes(self, start_vertex, end_vertex):
        """
        Return the vertices lying on some path from ``start_vertex`` to ``end_vertex``.

        These are the descendants of ``start_vertex`` that are also ancestors of ``end_vertex``, plus the two
        endpoints themselves (which are always included, even when no path exists).
        """
        on_path = set(nx.descendants(self.graph, start_vertex)) & set(nx.ancestors(self.graph, end_vertex))
        on_path |= {start_vertex, end_vertex}
        return on_path

    def _is_valid_path(self, path, start_nodes):
        """
        Check whether a collection of node names forms a self-contained subgraph.

        A path is valid if all nodes' predecessors are included in the path, where:
        * the predecessors of the start nodes are not validated
        * const predecessors are allowed to be outside of the path

        Args:
            path: node names forming the candidate path / subgraph (a set, for fast membership tests).
            start_nodes: node names whose predecessors are not validated.

        Returns:
            True if no node outside ``start_nodes`` has a non-const predecessor missing from ``path``.
        """
        for node_name in path:
            if node_name in start_nodes:
                continue
            node_vertex = self.graph.get_vertex_by_name(node_name)
            # presence of a missing non-const pred is what matters, not the truthiness of its name
            if any(pred.name not in path and not pred.is_const() for pred in self.graph.predecessors(node_vertex)):
                return False
        return True

    def _map_possible_paths(self):
        """
        Find all possible paths between the recommended start nodes and end nodes.

        There can be more than one path per end node, as it may be reached from different start nodes. A path is kept
        only when the end node is reachable from the start node and none of the path's vertices is an error vertex.

        Returns:
            dict mapping an end node name to a list of ``(start_node_name, path_length, path_node_names)`` tuples,
            longest first; end nodes are ordered by the length of their longest path, descending. Start and end names
            are visited in sorted order, so equal-length entries keep a deterministic (name-based) order instead of
            depending on set iteration order.
        """
        map_end_nodes_to_possible_paths = defaultdict(list)
        sorted_end_names = sorted(self._recommended_end_names)
        for start_node_name in sorted(self._recommended_start_names):
            start_vertex = self.graph.get_vertex_by_name(start_node_name)
            for end_node_name in sorted_end_names:
                if start_node_name == end_node_name:
                    continue
                end_vertex = self.graph.get_vertex_by_name(end_node_name)
                if not nx.has_path(self.graph, start_vertex, end_vertex):
                    continue
                path_nodes = self._find_intersecting_nodes(start_vertex, end_vertex)
                if path_nodes and all(node not in self._errors_dict for node in path_nodes):
                    map_end_nodes_to_possible_paths[end_node_name].append(
                        (start_node_name, len(path_nodes), {node.name for node in path_nodes})
                    )

        # sorted() is stable (also with reverse=True), so ties keep the deterministic insertion order above
        map_end_nodes_to_paths_sorted_by_len = {
            end_node_name: sorted(paths_list, key=lambda x: x[1], reverse=True)
            for end_node_name, paths_list in map_end_nodes_to_possible_paths.items()
        }
        sorted_end_nodes_by_max_path_len = sorted(
            map_end_nodes_to_paths_sorted_by_len.keys(),
            key=lambda k: map_end_nodes_to_paths_sorted_by_len[k][0][1],
            reverse=True,
        )
        return {
            end_node: map_end_nodes_to_paths_sorted_by_len[end_node] for end_node in sorted_end_nodes_by_max_path_len
        }

    def _is_additional_edge_node_relevant(self, start_node1, start_node2, end_node1, end_node2, compared_path):
        """
        Tell whether the second path's differing edge node has inputs/outputs outside ``compared_path``.

        * Same start nodes: look at the successors of ``end_node2``.
        * Otherwise, same end nodes: look at the predecessors of ``start_node2``.
        * Otherwise: not relevant.

        Returns:
            True if one of the examined neighbours is not part of ``compared_path`` (a collection of node names).
        """
        if start_node1 == start_node2:
            neighbours = self.graph.successors(self.graph.get_vertex_by_name(end_node2))
        elif end_node1 == end_node2:
            neighbours = self.graph.predecessors(self.graph.get_vertex_by_name(start_node2))
        else:
            return False
        return any(neighbour for neighbour in neighbours if neighbour.name not in compared_path)

    def _find_connected_components_union(self, possible_paths_mapping):
        """
        Group intersecting paths, while calculating the total length of each group.

        Every pair of paths from ``possible_paths_mapping`` (as built by ``_map_possible_paths``) is recorded at most
        once, under the end node of the path met first. Two paths intersect when their node-name sets share at least
        one node. A pair is skipped when both paths have the same node set, when one path's start node is the other's
        end node, or when one path is a strict subset of the other and ``_is_additional_edge_node_relevant`` reports
        no extra relevant input/output.
        * an end node can have multiple intersecting paths to different start nodes

        NOTE(review): the intersection is computed over all path nodes, start and end nodes included — confirm
        whether only intermediate nodes were meant to count.

        Returns:
            dict filled by ``_update_intersecting_paths``: end node -> ``{"start": ..., "total_len": ...,
            "intersecting_paths": {"<end idx>_<path idx>": {"start": ..., "end": ...}}}``.
        """
        intersecting_paths = {}
        # unordered pairs of (start node, end node) paths that were already recorded
        processed_pairs = set()
        for end_node1 in possible_paths_mapping:
            for start_node1, path_len1, path1 in possible_paths_mapping[end_node1]:
                for end_node_idx, end_node2 in enumerate(possible_paths_mapping):
                    for path_idx, (start_node2, path_len2, path2) in enumerate(possible_paths_mapping[end_node2]):
                        # skip identical node sets and paths chained start-to-end
                        if path1 == path2 or start_node1 == end_node2 or start_node2 == end_node1:
                            continue

                        processed_pair = frozenset([(start_node1, end_node1), (start_node2, end_node2)])
                        if processed_pair not in processed_pairs:
                            intersection_len = len(set(path1).intersection(path2))
                            # do not add paths that are a complete subsets of each other,
                            # unless the differing edge node has extra relevant input/output
                            if intersection_len:
                                if (path1 < path2 or path2 < path1) and not self._is_additional_edge_node_relevant(
                                    start_node1, start_node2, end_node1, end_node2, path1
                                ):
                                    # the pair is not marked as processed here, so its mirrored combination is still
                                    # evaluated on its own
                                    continue
                                path1_info = (end_node1, start_node1, path_len1)
                                path2_info = (end_node2, start_node2, path_len2)
                                # key of the second path: "<index of end_node2>_<index of path2 in its list>"
                                combined_path_idx = f"{end_node_idx}_{path_idx}"
                                self._update_intersecting_paths(
                                    intersecting_paths,
                                    path1_info,
                                    path2_info,
                                    intersection_len,
                                    combined_path_idx,
                                )
                            processed_pairs.add(processed_pair)
        return intersecting_paths

    @staticmethod
    def _update_intersecting_paths(
        intersecting_paths,
        path1_info,
        path2_info,
        intersection_len,
        path_idx,
    ):
        """
        Record that the second path intersects the group keyed by the first path's end node.

        Args:
            intersecting_paths: dict updated in place; created entries look like
                ``{"start": start_node1, "total_len": path_len1, "intersecting_paths": {}}``.
            path1_info: ``(end_node, start_node, path_len)`` of the path owning the group.
            path2_info: ``(end_node, start_node, path_len)`` of the intersecting path.
            intersection_len: number of nodes shared by the two paths.
            path_idx: identifier under which the intersecting path is stored.
        """
        end_node1, start_node1, path_len1 = path1_info
        end_node2, start_node2, path_len2 = path2_info
        group = intersecting_paths.setdefault(
            end_node1, {"start": start_node1, "total_len": path_len1, "intersecting_paths": {}}
        )
        # count only the nodes of the second path that the group does not already share with it
        group["total_len"] += path_len2 - intersection_len
        group["intersecting_paths"][f"{path_idx}"] = {"start": start_node2, "end": end_node2}

    def _extract_valid_subgraphs(self, intersecting_paths, possible_paths_mapping):
        """
        Validate the joint subgraph of every group of intersecting paths.

        For each group, the node sets of the group's own path and of all its intersecting paths are merged. The merged
        subgraph is valid when every node's predecessors (start nodes excluded, consts allowed) are included in it.

        Returns:
            ``(independent_possible_paths_mapping, valid_intersecting_paths)``: a shallow copy of
            ``possible_paths_mapping`` without the end nodes of valid groups, and the valid groups keyed by end node.
        """

        def _path_nodes(end_name, start_name):
            # node names of the first stored path from start_name to end_name (StopIteration if there is none)
            return next(path for path in possible_paths_mapping[end_name] if path[0] == start_name)[2]

        independent_possible_paths_mapping = possible_paths_mapping.copy()
        valid_intersecting_paths = {}
        for end_node, group in intersecting_paths.items():
            joint_nodes = set(_path_nodes(end_node, group["start"]))
            joint_starts = {group["start"]}
            for member in group["intersecting_paths"].values():
                joint_starts.add(member["start"])
                joint_nodes.update(_path_nodes(member["end"], member["start"]))
            if not self._is_valid_path(joint_nodes, joint_starts):
                continue
            valid_intersecting_paths[end_node] = group
            independent_possible_paths_mapping.pop(end_node)

        return independent_possible_paths_mapping, valid_intersecting_paths

    def _shared_ancestor_exists(self, vertex1, vertex2):
        """
        Return True if ``vertex1`` and ``vertex2`` have at least one common ancestor.
        """
        ancestors_of_first = set(nx.ancestors(self.graph, vertex1))
        return not ancestors_of_first.isdisjoint(nx.ancestors(self.graph, vertex2))

    def _are_paths_independent(self, start_node1, start_node2):
        """
        Check whether two start nodes are unrelated.

        They are independent when neither is reachable from the other and they have no common ancestor.
        """
        first = self.graph.get_vertex_by_name(start_node1)
        second = self.graph.get_vertex_by_name(start_node2)
        related = (
            nx.has_path(self.graph, first, second)
            or nx.has_path(self.graph, second, first)
            or self._shared_ancestor_exists(first, second)
        )
        return not related

    def _extract_longest_simple_path(self, independent_possible_paths_mapping):
        """
        Find the longest valid single-start, single-end path.

        Returns:
            ``(end_node, path_len)`` of the first longest valid path, or ``(None, float("-inf"))`` if none is valid.
        """
        best_end_node, best_len = None, float("-inf")
        for end_node, candidates in independent_possible_paths_mapping.items():
            for start_node, candidate_len, candidate_path in candidates:
                is_valid = self._is_valid_path(candidate_path, {start_node})
                if is_valid and candidate_len > best_len:
                    best_end_node, best_len = end_node, candidate_len
        return best_end_node, best_len

    def _extract_longest_subgraph(self, valid_intersecting_paths, max_possible_path_len):
        """
        Recommend the group of intersecting paths with the largest total length.

        Its start/end nodes (and those of all its intersecting paths) are added to the recommendations only if its
        total length exceeds ``max_possible_path_len``.
        """
        if not valid_intersecting_paths:
            return
        best_end_node = max(valid_intersecting_paths, key=lambda end: valid_intersecting_paths[end]["total_len"])
        best_group = valid_intersecting_paths[best_end_node]
        if best_group["total_len"] <= max_possible_path_len:
            return
        self._recommended_start_names.add(best_group["start"])
        self._recommended_end_names.add(best_end_node)
        for member in best_group["intersecting_paths"].values():
            self._recommended_start_names.add(member["start"])
            self._recommended_end_names.add(member["end"])

    def _add_valid_independent_components(self, independent_possible_paths_mapping, valid_intersecting_paths):
        """
        Add valid independent subgraphs and simple paths to the suggestions
        """
        dependent_start_nodes = set()  # set used to avoid validating the same start node multiple times
        # add all possible subgraphs that are independent from already suggested paths
        for end_node in valid_intersecting_paths:
            cur_start_nodes = {valid_intersecting_paths[end_node]["start"]}
            cur_start_nodes.update(
                intersecting_path["start"]
                for intersecting_path in valid_intersecting_paths[end_node]["intersecting_paths"].values()
            )
            if cur_start_nodes.intersection(dependent_start_nodes) or cur_start_nodes.intersection(
                self._recommended_start_names
            ):
                # if any of the current start nodes are already dependent, or already recommended, continue
                continue
            independent_candidate = True
            truncated_recommended_start_nodes = self._recommended_start_names.copy()
            truncated_recommended_start_nodes.difference_update(cur_start_nodes)
            for recommended_start_node in truncated_recommended_start_nodes:
                if not independent_candidate:
                    # if a previous start node is not independent, no need to check the rest
                    break
                for start_node in cur_start_nodes:
                    if start_node in dependent_start_nodes or not self._are_paths_independent(
                        start_node, recommended_start_node
                    ):
                        independent_candidate = False
                        dependent_start_nodes.add(start_node)
                        break
            if truncated_recommended_start_nodes and independent_candidate:
                self._recommended_start_names.add(start_node)
                self._recommended_end_names.add(end_node)
        # add all possible independent paths that are independent from already suggested paths
        for end_node in independent_possible_paths_mapping:
            for start_node, _, path in independent_possible_paths_mapping[end_node]:
                if start_node in dependent_start_nodes or start_node in self._recommended_start_names:
                    continue
                if self._is_valid_path(path, {start_node}):
                    independent_candidate = True
                    truncated_recommended_start_nodes = self._recommended_start_names.copy()
                    truncated_recommended_start_nodes.difference_update({start_node})
                    for recommended_start_node in truncated_recommended_start_nodes:
                        if not self._are_paths_independent(start_node, recommended_start_node):
                            independent_candidate = False
                            dependent_start_nodes.add(start_node)
                            break
                    if truncated_recommended_start_nodes and independent_candidate:
                        self._recommended_start_names.add(start_node)
                        self._recommended_end_names.add(end_node)

    def _extract_independent_optimal_subgraphs(self, independent_possible_paths_mapping, valid_intersecting_paths):
        self._recommended_start_names.clear()
        self._recommended_end_names.clear()
        # find absolute longest path and equivalent key from all possible paths
        max_possible_path_key, max_possible_path_len = self._extract_longest_simple_path(
            independent_possible_paths_mapping
        )
        # if there are intersecting paths, choose the longest combination
        self._extract_longest_subgraph(valid_intersecting_paths, max_possible_path_len)
        # and vice versa (not nx.has_path(x,y) and not nx.has_path(y, x))
        if (
            self._recommended_start_names == set() and self._recommended_end_names == set() and max_possible_path_key
        ):  # no intersecting paths, or intersecting paths are shorter
            self._recommended_start_names.add(independent_possible_paths_mapping[max_possible_path_key][0][0])
            self._recommended_end_names.add(max_possible_path_key)
        # add all possible paths that do not have a path from their start node to other start nodes in the suggestions
        self._add_valid_independent_components(independent_possible_paths_mapping, valid_intersecting_paths)

    def _handle_recordable_parser_error(self, vertex, e):
        """
        Record a parsing error for ``vertex`` together with nearby candidate node names.

        Stores ``(e, recommended_succs, recommended_preds)`` in ``self._errors_dict[vertex]``.

        ``recommended_succs`` holds the names of the direct successors of ``vertex`` that are neither already in
        ``self._errors_dict`` nor shape ops.

        ``recommended_preds`` is built from each direct predecessor that is not an input op and not already in
        ``self._errors_dict``:
          - walk upward while the current vertex should be skipped (``_should_skip_vertex``), has exactly one
            predecessor, and is not in ``self._errors_dict``;
          - record the vertex reached if it is not skipped, not in ``self._errors_dict``, and not a shape op.
        Presumably these names are later offered to the user as alternative start/end nodes — confirm with the
        consumers of ``_errors_dict``.

        Args:
            vertex: the vertex whose processing failed.
            e: the exception raised while processing it.
        """
        self._logger.debug(f"An error was encountered while processing vertex {vertex}")
        graph = vertex.graph
        errors = self._errors_dict

        recommended_succs = [
            successor.name
            for successor in graph.successors(vertex)
            if successor not in errors and not successor.is_shape_op()
        ]

        recommended_preds = []
        for candidate in graph.predecessors(vertex):
            if candidate.op == INPUT_OP or candidate in errors:
                continue
            # the direct predecessor may be a skipped vertex; climb single-input chains to a usable one
            upstream = list(graph.predecessors(candidate))
            while self._should_skip_vertex(candidate) and len(upstream) == 1 and candidate not in errors:
                candidate = upstream[0]
                upstream = list(graph.predecessors(candidate))
            if self._should_skip_vertex(candidate) or candidate in errors or candidate.is_shape_op():
                continue
            recommended_preds.append(candidate.name)

        errors[vertex] = (e, recommended_succs, recommended_preds)

    def _update_consumed_vertices_states(self, consumed_vertices, should_assign_vertex_to_layer=True):
        """
        Mark ``consumed_vertices`` as consumed by the current layer.

        A vertex's visited state is only ever raised to ``VertexState.CONSUMED``. A vertex already recorded as
        ``CONSUMED`` or ``PROCESSED`` keeps its state.

        Args:
            consumed_vertices: vertices folded into ``self._current_layer``.
            should_assign_vertex_to_layer: when True, every non-const vertex is also mapped to
                ``self._current_layer`` in ``self._vertices_to_layers``.
        """
        for vertex in consumed_vertices:
            if should_assign_vertex_to_layer and not vertex.is_const():
                self._vertices_to_layers[vertex] = self._current_layer

            # ``_visited_states`` is keyed by vertex objects (as everywhere else it is written), so look the vertex
            # itself up. The previous ``vertex.name`` lookup generally missed, which downgraded PROCESSED vertices
            # back to CONSUMED.
            current_state = self._visited_states.get(vertex)
            if current_state is None or current_state.value < VertexState.CONSUMED.value:
                self._visited_states[vertex] = VertexState.CONSUMED

    def _consume_flatten_chain(self, pred, layer):
        pass

    def _consume_pre_layer_op(self, pred, layer):
        """
        Fold the pre-layer op ``pred`` into ``layer``.

        The op is:
          - appended to ``self._vertices_with_edges``;
          - marked ``VertexState.CONSUMED``;
          - mapped to ``layer``.
        ``layer.input_vertex_order`` is then set to ``pred.input`` itself (the same object, not a copy).
        """
        edge_vertices = self._vertices_with_edges
        edge_vertices.append(pred)
        self._visited_states[pred] = VertexState.CONSUMED
        self._vertices_to_layers[pred] = layer
        # the layer now takes its inputs straight from the consumed op's inputs
        layer.input_vertex_order = pred.input

    def _handle_consumed_vertices(self, consumed_vertices, should_assign_vertex_to_layer=True):
        """
        Record the vertices consumed by the current layer and absorb eligible pre-layer ops feeding it.

        First, a non-empty ``consumed_vertices`` is passed to ``_update_consumed_vertices_states``.

        Then each predecessor of ``self._current_vertex`` is examined. It is absorbed when it:
          - is a pre-layer op;
          - is not yet mapped to a different layer;
          - should be skipped (``_should_skip_vertex``).
        An absorbed predecessor is appended to ``self._vertices_with_edges`` and marked ``VertexState.CONSUMED``.
        It is then linked to a layer:
          - with a single successor, it is mapped to the current vertex's layer. ``input_vertex_order`` entries
            that refer to it are replaced with its first input, and ``_consume_flatten_chain`` is invoked;
          - otherwise, it is mapped to the layer of its single predecessor that is already mapped to a layer.

        Args:
            consumed_vertices: vertices folded into the current layer (may be empty).
            should_assign_vertex_to_layer: forwarded to ``_update_consumed_vertices_states``.

        Raises:
            UnsupportedModelError: if a multi-successor pre-layer op does not have exactly one predecessor
                already mapped to a layer.
        """
        if consumed_vertices:
            self._update_consumed_vertices_states(consumed_vertices, should_assign_vertex_to_layer)

        # handle possible pre-layer ops that need to be consumed
        # NOTE(review): assumes the current vertex is already mapped to a layer; the lookup below fails otherwise.
        vertex = self._current_vertex
        layer = self._vertices_to_layers[vertex]
        for pred in self._graph.predecessors(vertex):
            # only absorb skippable pre-layer ops that are unmapped or already attributed to this same layer
            if (
                pred.is_pre_layer_op()
                and (pred not in self._vertices_to_layers or self._vertices_to_layers[pred] == layer)
                and self._should_skip_vertex(pred)
            ):
                # Add predecessor to vertices with edges to maintain connection to previous layer
                self._vertices_with_edges.append(pred)
                self._visited_states[pred] = VertexState.CONSUMED

                # Change input vertex order to the inputs of first vertex consumed for the layer
                if len(list(self._graph.successors(pred))) == 1:
                    self._vertices_to_layers[pred] = layer
                    # NOTE(review): `pred.name in inp` is a containment test. If entries are strings, it also matches
                    # entries that merely contain pred.name (e.g. suffixed tensor names) — confirm this is intended.
                    layer.input_vertex_order = [
                        pred.input[0] if pred.name in inp else inp for inp in layer.input_vertex_order
                    ]

                    # In case multiple flattens are chained
                    self._consume_flatten_chain(pred, layer)

                # If pre layer op has multiple successors, we link it to its predecessor
                else:
                    # exactly one already-mapped predecessor is required for an unambiguous link
                    layer_preds = [x for x in self._graph.predecessors(pred) if x in self._vertices_to_layers]
                    if len(layer_preds) == 1:
                        self._vertices_to_layers[pred] = self._vertices_to_layers[layer_preds[0]]
                    else:
                        raise UnsupportedModelError(f"Could not link node {pred.name} to a layer in the graph")
