#!/usr/bin/env python
import os
from enum import Enum
from operator import attrgetter

import networkx as nx

from hailo_sdk_client.model_translator.exceptions import UnexpectedNodeError
from hailo_sdk_client.model_translator.parsing_report import NetParamsReport, ParsingReport
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.paths_manager.paths import SDKPaths
from hailo_sdk_common.tools.models_translator_helper import valid_orig_name


class VertexParsingStatus(Enum):
    """Parsing outcome tracked on every ``NNGraphNode``.

    The string value of each member is what ``NNGraphNode.to_json`` writes
    under ``params["parsing_status"]``.
    """

    # The vertex was parsed.
    PARSED = "parsed"
    # Parsing of the vertex failed.
    FAILED = "failed"
    # Initial status given to every NNGraphNode on construction.
    OUT_OF_SCOPE = "out_of_scope"


class NNGraphNode:
    """Base class for the vertices of an ``NNGraph``.

    Wraps the original framework node description (``node_proto``) and tracks
    per-vertex parsing bookkeeping:
    - whether the vertex lies in the valid subgraph;
    - its ``VertexParsingStatus``;
    - the output format it produces.

    ``name``, ``op`` and ``input`` start as ``None`` and are presumably filled
    in by concrete subclasses (TODO confirm). ``get_input_shapes`` and
    ``get_output_shapes`` must be implemented by subclasses.

    Each vertex is hashed by a unique value taken from a class-level creation
    counter. No ``__eq__`` is defined, so vertices compare by identity.
    """

    # Next hash value to hand out; shared by NNGraphNode and all subclasses.
    next_hash = 0

    def __init__(self, node_proto, graph):
        """Create a vertex.

        Args:
            node_proto: Original framework-specific node description. It is
                stored as-is and stringified by ``get_original_info_to_json``.
            graph: The owning graph. Its ``predecessors`` method is used to
                compute ``input_formats``.
        """
        self.name = None
        self.op = None
        self.input = None
        self._info = node_proto
        # Read (and, below, bump) the counter on the base class explicitly, so a
        # subclass that shadows ``next_hash`` cannot produce colliding hashes.
        self._hash = NNGraphNode.next_hash
        self._in_valid_subgraph = False
        self._graph = graph
        self._parsing_status = VertexParsingStatus.OUT_OF_SCOPE
        # NOTE(review): defaults to True; the flag is presumably cleared by the
        # traversal that marks error/end-node descendants — confirm with callers.
        self._is_descendant_of_error_and_optional_end_node = True
        self._output_format = None
        NNGraphNode.next_hash += 1

    def __hash__(self):
        return self._hash

    def __str__(self):
        # ``name`` is None until a subclass sets it. Returning it unconverted
        # would make str() raise TypeError.
        return str(self.name)

    @property
    def graph(self):
        """The graph this vertex belongs to."""
        return self._graph

    @property
    def in_valid_subgraph(self):
        """Whether this vertex is part of the valid subgraph (False by default)."""
        return self._in_valid_subgraph

    @in_valid_subgraph.setter
    def in_valid_subgraph(self, val):
        default_logger().debug(f"Marked vertex {self.name} in_valid_subgraph={val}")
        self._in_valid_subgraph = val

    @property
    def parsing_status(self):
        """The vertex's ``VertexParsingStatus``. Starts as ``OUT_OF_SCOPE``."""
        return self._parsing_status

    @parsing_status.setter
    def parsing_status(self, parsing_status):
        self._parsing_status = parsing_status

    @property
    def is_descendant_of_error_and_optional_end_node(self):
        """Flag maintained by callers; True by default."""
        return self._is_descendant_of_error_and_optional_end_node

    @is_descendant_of_error_and_optional_end_node.setter
    def is_descendant_of_error_and_optional_end_node(self, is_descendant_of_error_and_optional_end_node):
        self._is_descendant_of_error_and_optional_end_node = is_descendant_of_error_and_optional_end_node

    @property
    def output_format(self):
        """Output format produced by this vertex, or None if not yet known."""
        return self._output_format

    @output_format.setter
    def output_format(self, output_format):
        self._output_format = output_format

    @property
    def input_formats(self):
        """List the output formats of all predecessors, in the order the graph yields them."""
        return [pred.output_format for pred in self.graph.predecessors(self)]

    @property
    def input_format(self):
        """Return the first predecessor's output format, or None.

        Returns None in two cases: the vertex has no predecessors, or any
        predecessor's output format is still unknown (None).
        """
        # Evaluate the property once; each access walks the graph's predecessors.
        formats = self.input_formats
        if not formats or None in formats:
            return None
        return formats[0]

    def get_input_shapes(self):
        """Return the input shapes of this vertex. Must be implemented by subclasses."""
        raise NotImplementedError

    def get_output_shapes(self, **kwargs):
        """Return the output shapes of this vertex. Must be implemented by subclasses."""
        raise NotImplementedError

    def get_original_info_to_json(self):
        """Return the wrapped original node description as a string."""
        return str(self._info)

    def to_json(self):
        """Build the JSON-serializable summary of this vertex used in parsing reports.

        Returns:
            A dict with the keys ``type``, ``input_shapes``, ``output_shapes``
            and ``params``. ``params`` holds ``original_info`` and the string
            value of ``parsing_status``.
        """
        res = {
            "type": self.op,
            "input_shapes": self.get_input_shapes(),
            "output_shapes": self.get_output_shapes(),
            "params": {
                "original_info": self.get_original_info_to_json(),
                "parsing_status": self.parsing_status.value,
            },
        }
        return res

    def update_output_format(self, pred):
        """Update this vertex's output format from predecessor ``pred``.

        No-op in the base class; presumably overridden by subclasses.
        """
        pass


class NNGraph(nx.DiGraph):
    """Directed graph of ``NNGraphNode`` vertices built while parsing a model.

    Besides the networkx structure, it keeps two name indices:
    - by the vertex's original name;
    - by the ``valid_orig_name``-normalized form of that name.
    """

    def __init__(self, raw_graph_proto, values):
        """Create an empty graph.

        Args:
            raw_graph_proto: Raw framework graph description; stored as-is.
            values: Stored as-is and exposed through the ``values`` property.
        """
        super().__init__()
        self._raw_proto = raw_graph_proto
        self._values = values
        self._vertices_by_name = {}
        self._vertices_by_valid_name = {}

    @property
    def values(self):
        """The ``values`` object passed to the constructor."""
        return self._values

    @property
    def vertices_by_name(self):
        """Mapping from original vertex name to vertex (the live dict, not a copy)."""
        return self._vertices_by_name

    def nodes_toposorted(self, valid_nodes_only=True):
        """Return the graph's vertices in a deterministic order.

        Args:
            valid_nodes_only: When True, vertices whose ``in_valid_subgraph`` is
                False are dropped, together with their edges, before ordering.

        Returns:
            A list of vertices. If the remaining graph is acyclic, the order is
            a lexicographical topological order with ties broken by ``name``.
            Otherwise no topological order exists, and the vertices are simply
            sorted by ``name``.
        """
        # Work on a plain DiGraph copy so the removals below leave this graph intact.
        graph_copy = nx.DiGraph()
        graph_copy.add_nodes_from(self.nodes)
        graph_copy.add_edges_from(self.edges)
        if valid_nodes_only:
            graph_copy.remove_nodes_from([node for node in self.nodes if not node.in_valid_subgraph])
        if nx.is_directed_acyclic_graph(graph_copy):
            return list(nx.algorithms.dag.lexicographical_topological_sort(graph_copy, key=attrgetter("name")))
        # Cyclic graph: fall back to plain name ordering.
        return sorted(graph_copy.nodes, key=attrgetter("name"))

    def get_vertex_by_name(self, name):
        """Return the vertex registered under original name ``name``.

        Raises:
            UnexpectedNodeError: If no vertex is registered under ``name``.
        """
        try:
            return self._vertices_by_name[name]
        except KeyError:
            raise UnexpectedNodeError(f"Can't find vertex {name} in graph") from None

    def get_vertex_by_valid_name(self, name):
        """Return the vertex registered under normalized name ``name``, or None if absent."""
        return self._vertices_by_valid_name.get(name)

    def add_vertex_by_name(self, vertex):
        """Index ``vertex`` by its original name and by its ``valid_orig_name`` form.

        An existing entry under either key is overwritten. This only updates
        the name indices; it does not add the vertex to the graph structure.
        """
        self._vertices_by_name[vertex.name] = vertex
        self._vertices_by_valid_name[valid_orig_name(vertex.name)] = vertex

    def visualize(self, filename_prefix):
        """Write ``<filename_prefix>.dot`` and render it to ``<filename_prefix>.svg``.

        This is best-effort. If graphviz is unavailable, a warning is logged
        and nothing is written. A failure to launch ``dot`` is logged and does
        not raise.
        """
        if not SDKPaths().has_graphviz:
            default_logger().warning("Cannot visualize NN graph because graphviz package is unavailable")
            return
        import subprocess

        dot_path = f"{filename_prefix}.dot"
        svg_path = f"{filename_prefix}.svg"
        nx.drawing.nx_agraph.write_dot(self, dot_path)
        # Pass an argument list (no shell) so paths with quotes, spaces or shell
        # metacharacters reach ``dot`` verbatim. check=False keeps the former
        # os.system behavior of ignoring a non-zero exit status.
        try:
            subprocess.run(["dot", "-Tsvg", dot_path, "-o", svg_path], check=False)
        except OSError as err:
            default_logger().warning(f"Failed to render {svg_path} with graphviz: {err}")

    def get_parsing_report(self, start_nodes, end_nodes, blocks, meta_arch):
        """Build a ``ParsingReport`` covering every vertex in the graph.

        Invalid-subgraph vertices are included too (``valid_nodes_only=False``).

        Args:
            start_nodes: Start node names, recorded in the report's net params.
            end_nodes: End node names, recorded in the report's net params.
            blocks: Passed through to the report unchanged.
            meta_arch: Recorded as ``detected_post_process``.

        Returns:
            A ``ParsingReport`` named ``"parsing_report"``, built as follows:
            - ``layers`` maps each vertex's ``valid_orig_name`` to its
              ``to_json()`` dict, extended with ``input``/``output`` name lists.
            - ``output_format`` holds the known (non-None) output formats.
        """
        net_params = NetParamsReport(
            start_node_names=start_nodes,
            end_node_names=end_nodes,
            detected_post_process=meta_arch,
        )
        layers = {}
        out_format = {}
        for current_layer in self.nodes_toposorted(valid_nodes_only=False):
            layer_json = current_layer.to_json()
            layer_json["input"] = [valid_orig_name(str(layer.name)) for layer in self.predecessors(current_layer)]
            layer_json["output"] = [valid_orig_name(str(layer.name)) for layer in self.successors(current_layer)]
            valid_curr_name = valid_orig_name(str(current_layer.name))
            layers[valid_curr_name] = layer_json
            if current_layer.output_format is not None:
                out_format[valid_curr_name] = current_layer.output_format

        return ParsingReport(
            name="parsing_report",
            net_params=net_params,
            layers=layers,
            blocks=blocks,
            output_format=out_format,
        )
