#!/usr/bin/env python
from abc import ABC, abstractmethod
from collections import OrderedDict

import numpy as np

from hailo_sdk_client.model_translator.exceptions import (
    MisspellNodeError,
    UnsupportedModelError,
    UnsupportedWeightsError,
)
from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN
from hailo_sdk_common.hailo_nn.hn_definitions import HnStage, LayerType
from hailo_sdk_common.hailo_nn.hn_layers import (
    ConcatLayer,
    EWDivLayer,
    EWMultLayer,
    OutputLayer,
    ShortcutLayer,
)
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.tools.models_translator_helper import valid_orig_name


class HailoNNBaseConverter(ABC):
    """Abstract root of the converter hierarchy.

    Keeps the source graph and the requested end-node names exactly as given;
    concrete subclasses supply :meth:`convert_model`.
    """

    def __init__(self, graph, end_node_names):
        # Stored verbatim -- no validation or copying happens at this level.
        self._graph, self._end_node_names = graph, end_node_names

    @property
    def graph(self):
        """The source graph this converter was constructed with (read-only)."""
        return self._graph

    @abstractmethod
    def convert_model(self):
        """Translate :attr:`graph` into the target representation; implemented by subclasses."""


class HailoNNConverter(HailoNNBaseConverter, ABC):
    """Common driver that translates a framework graph into a ``HailoNN`` layer graph.

    :meth:`convert_model` runs a fixed pipeline:
    - subclass layer creation (:meth:`_create_layers`);
    - edge wiring and index bookkeeping;
    - a series of graph fix-ups (several are no-op hooks here that subclasses may override);
    - shape calculation and output-layer insertion;
    - parameter validation.

    Subclasses are expected to fill ``_vertices_to_layers`` (framework vertex -> layer) and
    ``_vertices_with_edges`` (e.g. through :meth:`_add_layer`) while creating layers; both are
    read by the wiring and index-update steps below.
    """

    def __init__(self, graph, start_node_names=None, end_node_names=None):
        """Initialize conversion state and mark the part of ``graph`` that will be converted.

        Args:
            graph: The framework graph wrapper to convert.
            start_node_names: Optional names of vertices at which the backwards scope walk stops.
            end_node_names: Optional names of the requested output vertices; ``None`` means the
                whole graph is in scope.

        Raises:
            MisspellNodeError: If an end node name cannot be found in ``graph``.
        """
        super().__init__(graph, end_node_names)
        self._logger = default_logger()
        self._vertex_index = 0
        self._vertices_queue = []
        self._current_vertex = None
        self._start_node_names = start_node_names
        self._mode = None
        self._sub_mode = None
        self._layers_graph = HailoNN(stage=HnStage.PRE_FUSED.value, end_nodes_order=end_node_names)
        self._current_layer = None
        self._vertices_to_layers = {}
        self._vertices_with_edges = []
        self._errors_dict = {}
        # Must run after _start_node_names is assigned: the scope walk stops at start nodes.
        self._calculate_valid_subgraph_scope()
        self._successful_end_nodes = set()
        self._nn_framework = None

    @property
    def end_node_names(self):
        """The end node names passed at construction (may be ``None``)."""
        return self._end_node_names

    def _get_node_names_for_report(self, is_start):
        """Return the original names of the real input (``is_start``) or output layers.

        Input layers report their first original name; output layers report their last one.
        """
        hailo_nn = self._layers_graph
        real_io_layers = hailo_nn.get_real_input_layers() if is_start else hailo_nn.get_real_output_layers()
        orig_idx = 0 if is_start else -1  # first original name for start nodes and last for end nodes
        return [x.original_names[orig_idx] for x in real_io_layers]

    def get_parsing_report(self, from_error=False):
        """Build the source graph's parsing report for the converted model (best-effort).

        Args:
            from_error: When True, start/end node names are not collected and ``None`` is passed
                for both (presumably because the layer graph may be incomplete after a failure).

        Returns:
            The result of ``graph.get_parsing_report``, or ``None`` if building it raised.
        """
        try:
            meta_arch = None
            if self._layers_graph.detected_anchors:
                meta_arch = self._layers_graph.detected_anchors["meta_arch"]
            start_names = None if from_error else self._get_node_names_for_report(is_start=True)
            end_names = None if from_error else self._get_node_names_for_report(is_start=False)
            return self.graph.get_parsing_report(start_names, end_names, self._layers_graph.blocks, meta_arch)
        except Exception as err:
            # Deliberately broad: the report is optional, so any failure is only logged.
            self._logger.info(f"Unable to export parsing report: {err!s}")
            return None

    def convert_model(self):
        """Run the full conversion pipeline and return the resulting ``HailoNN`` graph.

        The step order matters: names/indices must exist before inputs/outputs are indexed,
        shapes are needed before output layers are added, and so on.
        """
        self._validate_model_params()
        self._validate_bn_ops_in_training()
        self._create_layers()
        self._add_layers_connections()
        self._layers_graph.set_names_and_indices()
        self._update_input_indices()
        self._update_output_indices()
        self._handle_multi_output_copies_split()
        self._handle_duplicate_inputs_concat()
        self._handle_inner_product_matmul()
        self._handle_tokens_matmul()
        self._calculate_shapes(validate_shapes=False)
        self._add_output_layers()
        self._handle_fused_layers()
        self._replace_feature_splitter_with_slices()
        self._layers_graph.remove_null_layers()
        self._layers_graph.update_output_indices()
        self._validate_nan_values_in_params()
        self._detect_nms_anchors()
        return self._layers_graph

    # ----- Optional hooks: no-ops here, overridable by framework-specific subclasses. -----

    def _detect_nms_anchors(self):
        """Hook: detect NMS anchors in the converted graph (no-op by default)."""

    def _replace_feature_splitter_with_slices(self):
        """Hook: replace feature splitters with slice layers (no-op by default)."""

    def _handle_multi_output_copies_split(self):
        """Hook: handle multi-output copies (no-op by default)."""

    def _handle_fused_layers(self):
        """Hook: handle fused layers (no-op by default)."""

    def _handle_inner_product_matmul(self):
        """Hook: handle inner-product matmuls (no-op by default)."""

    def _handle_tokens_matmul(self):
        """Hook: handle token matmuls (no-op by default)."""

    def _validate_model_params(self):
        """Hook: validate model-level parameters before conversion (no-op by default)."""

    def _validate_bn_ops_in_training(self):
        """Warn when :meth:`_has_bn_ops_in_training` reports training-mode batch norms."""
        if self._has_bn_ops_in_training():
            self._logger.warning(
                "This model has Batch normalization layers exported in training mode, expect "
                "different inference results between native emulation and original session. "
                "It's recommended to run evaluation on the full-precision model before "
                "optimization to verify accuracy is still good enough.",
            )

    def _has_bn_ops_in_training(self):
        """Hook: return True if the model contains training-mode batch norms (default False)."""
        return False

    def _create_layer(self, type, index, name, inputs, input_indices, outputs):
        """Instantiate the layer class ``type``, set its wiring fields and add it to the graph.

        The parameter keeps the builtin-shadowing name ``type`` for compatibility with callers.

        Returns:
            The new layer, already added to ``self._layers_graph``.
        """
        layer = type()
        layer.index = index
        layer.name = name
        layer.inputs = inputs
        layer.input_indices = input_indices
        layer.outputs = outputs
        self._layers_graph.add_node(layer)
        return layer

    def _handle_tiled_concat(self, layer: ConcatLayer, idx):
        """Rewrite a concat whose input order is one sequence repeated twice.

        Before the rewrite, ``layer`` concatenates ``[a, b, ..., a, b, ...]``. Afterwards,
        ``layer`` concatenates only the first half and feeds both a new shortcut and a new
        concat. The new concat joins ``layer`` with the shortcut (a copy of ``layer``), which
        reproduces the tiled result. The new concat takes over ``layer``'s original names and
        all of its former successors.

        Args:
            layer: The concat layer to rewrite.
            idx: First free layer index; ``idx`` goes to the shortcut and ``idx + 1`` to the
                new concat.
        """
        orig_succs = list(self._layers_graph.successors(layer))
        shortcut_name = f"{layer.name}_tile_shortcut"
        concat_name = f"{layer.name}_tile_concat"
        shortcut = self._create_layer(ShortcutLayer, idx, shortcut_name, [layer.name], [layer.index], [concat_name])
        concat = self._create_layer(
            ConcatLayer,
            idx + 1,
            concat_name,
            [layer.name, shortcut_name],
            [layer.index, idx],
            layer.outputs.copy(),
        )
        concat.axis = layer.axis
        concat.original_names = layer.original_names.copy()
        concat.input_list = [layer, shortcut]

        # The original concat now only covers the first (non-repeated) half of its inputs.
        layer.original_names = []
        layer.input_vertex_order = layer.input_vertex_order[: len(layer.input_vertex_order) // 2]
        layer.outputs = [concat.name, shortcut.name]

        self._layers_graph.add_edge(layer, shortcut)
        self._layers_graph.add_edge(shortcut, concat)
        self._layers_graph.add_edge(layer, concat)

        # Move every former consumer of `layer` onto the new concat.
        for succ in orig_succs:
            self._layers_graph.remove_edge(layer, succ)
            self._layers_graph.add_edge(concat, succ)
            succ.replace_input_layer(layer.name, concat.name)
            succ.replace_input_index(layer.index, concat.index)

    def _handle_duplicate_inputs_concat(self):
        """Give each repeated concat input its own producer layer.

        A concat needs handling when its ``input_list`` is shorter than its
        ``input_vertex_order``, which presumably means one producer appears several times.
        There are two cases:
        - Tiled order ``[x..., x...]`` with more than one distinct input: delegate to
          :meth:`_handle_tiled_concat`.
        - Otherwise: insert a ``ShortcutLayer`` copy of the producer for every repeated
          occurrence after the first.

        Raises:
            UnsupportedModelError: If the repeated inputs cannot be matched to known producers.
        """
        for layer in list(self._layers_graph):
            if layer.op == LayerType.concat:
                input_vertices = layer.input_vertex_order[:]
                if len(layer.input_list) >= len(input_vertices):
                    continue

                next_idx = self._layers_graph.get_next_index()
                tiled_xx_cond = (
                    input_vertices[0 : len(input_vertices) // 2] == input_vertices[len(input_vertices) // 2 :]
                )
                if len(layer.input_list) > 1 and tiled_xx_cond:
                    self._handle_tiled_concat(layer, next_idx)
                    continue

                # Map each producer's original name to every position it occupies in the
                # concat's vertex order (comparing the part before any ":" suffix).
                original_names = [x.original_names[-1] for x in layer.input_list]
                orig_to_idx = OrderedDict()
                for orig in original_names:
                    orig_to_idx[orig] = [
                        i for i, name in enumerate(layer.input_vertex_order) if name.split(":")[0] == orig.split(":")[0]
                    ]

                if sum(len(indices) for indices in orig_to_idx.values()) != len(layer.input_vertex_order):
                    raise UnsupportedModelError(f"Could not detect inputs to concat layer {layer.full_name_msg}")

                input_list_idx_shift = 0
                for in_vertex_name, relevant_indices in orig_to_idx.items():
                    # add shortcut for each copy of the concat input
                    input_list_idx = original_names.index(in_vertex_name) + input_list_idx_shift
                    concat_pred = layer.input_list[input_list_idx]
                    for idx in relevant_indices[1:]:
                        name = f"{layer.name}_shortcut{idx}"
                        shortcut = self._create_layer(
                            ShortcutLayer,
                            next_idx,
                            name,
                            [concat_pred.name],
                            [concat_pred.index],
                            [layer.name],
                        )
                        next_idx += 1
                        # Inserting before existing entries shifts where later producers sit.
                        if layer.input_list[idx:]:
                            input_list_idx_shift += 1
                        layer.input_list = layer.input_list[:idx] + [shortcut] + layer.input_list[idx:]
                        layer.inputs = layer.inputs[:idx] + [shortcut.name] + layer.inputs[idx:]
                        layer.input_indices = layer.input_indices[:idx] + [shortcut.index] + layer.input_indices[idx:]
                        concat_pred.outputs.append(shortcut.name)
                        # NOTE(review): _create_layer already added the shortcut to the graph;
                        # this second add_node is presumably a no-op -- confirm before removing.
                        self._layers_graph.add_node(shortcut)
                        self._layers_graph.add_edge(concat_pred, shortcut)
                        self._layers_graph.add_edge(shortcut, layer)
                        self._logger.debug(f"Added shortcut layer {shortcut.name} before concat layer {layer.name}")

    @abstractmethod
    def _create_layers(self):
        """Create the layers for the in-scope vertices (framework-specific)."""

    def _calculate_shapes(self, validate_shapes=True):
        """Delegate shape inference to the layer graph."""
        self._layers_graph.calculate_shapes(validate_shapes=validate_shapes)

    def _add_output_layers(self, fused_activations=None):
        """Attach an ``OutputLayer`` after every layer that should be a network output.

        A layer gets an output when any of these holds:
        - it has no successors;
        - it was translated from a requested end node, unless that node matches one of
          ``fused_activations``;
        - its original vertex exposes ``output_tensors_indices`` that include a model output.

        Args:
            fused_activations: Optional iterable of objects with a ``name``. Matching layers are
                not made outputs on the end-node criterion (a tflite edge case).

        Raises:
            UnsupportedModelError: If an output layer cannot be created or connected.
        """
        edges_to_add = []
        fused_activations_orig_names = None if fused_activations is None else [x.name for x in fused_activations]
        for layer in list(self._layers_graph):
            has_output_layer = False
            if layer.original_names:
                orig_vertex = self.graph.get_vertex_by_valid_name(layer.original_names[0])
                if orig_vertex and hasattr(orig_vertex, "output_tensors_indices"):
                    output_indices = orig_vertex.output_tensors_indices
                    has_output_layer = any(idx in self._graph._model_graph.OutputsAsNumpy() for idx in output_indices)

            is_original_output = self._end_node_names is not None and any(
                orig_name in self._end_node_names for orig_name in layer.original_names
            )

            # edge case for tflite - fused activations non-existing during graph creation
            has_fused_activation = fused_activations_orig_names is not None and any(
                orig_name in fused_activations_orig_names for orig_name in layer.original_names
            )

            if (
                self._layers_graph.out_degree(layer) == 0
                or (is_original_output and not has_fused_activation)
                or has_output_layer
            ):
                try:
                    out_layer = OutputLayer.from_data("out", layer.output_shape)
                except UnsupportedModelError as e:
                    raise UnsupportedModelError(
                        f"Error occurred in layer {layer.full_name_msg}: {e.client_message}"
                    ) from e

                out_layer.input_vertex_order = [layer.original_names[0]]
                edges_to_add.append((layer, out_layer))

        added_layers = []
        for src_layer, out_layer in edges_to_add:
            self._layers_graph.add_node(out_layer)
            self._layers_graph.add_edge(src_layer, out_layer)
            try:
                out_layer.add_input_by_vertex(src_layer, input_name=src_layer.original_names[0])
            except UnsupportedModelError as e:
                # Nothing in this method sets the output layer's original names. If they are
                # empty, fall back to the producer's name so formatting the message cannot raise
                # IndexError and hide the real error.
                translated_from = (
                    out_layer.original_names[-1] if out_layer.original_names else src_layer.original_names[0]
                )
                raise UnsupportedModelError(
                    f"Error occurred in layer {out_layer.name} (translated from "
                    f"{translated_from}): {e.client_message}",
                ) from e
            added_layers.append(out_layer)
            self._logger.debug(f"Added output layer {out_layer}")

        if added_layers:
            # New layers need names/indices before their input indices can be recorded.
            self._layers_graph.set_names_and_indices()
            for layer in added_layers:
                self._update_layer_input_indices(layer)

    def _add_layer(self, layer, has_edge=True):
        """Add ``layer`` (if not None) to the graph and make it the current layer.

        When ``has_edge`` is True, the current vertex is queued for edge wiring in
        :meth:`_add_layers_connections`. Passing ``None`` only clears the current layer.
        """
        if layer is not None:
            self._layers_graph.add_node(layer)
            self._current_layer = layer
            if has_edge:
                self._vertices_with_edges.append(self._current_vertex)
            self._logger.debug(f"Added new layer, op={self._current_layer.op}, vertex={self._current_vertex}")
        else:
            self._current_layer = None

    def _add_layer_connections(self, vertex):
        """Connect the layer of ``vertex`` to the layers of its converted predecessors.

        Raises:
            UnsupportedModelError: If a predecessor cannot be registered as an input.
        """
        layer = self._vertices_to_layers[vertex]
        for prev_vertex, prev_layer in self._predecessors(vertex):
            self._logger.debug(f"Added new edge between vertices {prev_vertex}->{vertex}")
            self._layers_graph.add_edge(prev_layer, layer)
            try:
                layer.add_input_by_vertex(prev_layer, input_vertex=prev_vertex)
            except UnsupportedModelError as e:
                raise UnsupportedModelError(
                    f"Error occurred in layer {layer.name} (translated from {vertex.name}): {e.client_message}",
                ) from e

        # Order-sensitive multi-input layers also keep an explicit, vertex-ordered input list.
        if isinstance(layer, (ConcatLayer, EWDivLayer, EWMultLayer)):
            for inp in layer.inputs_by_vertex_order:
                layer.append_to_input_list(inp)

    def _add_layers_connections(self):
        """Wire edges for every vertex recorded by :meth:`_add_layer`."""
        for vertex in self._vertices_with_edges:
            self._add_layer_connections(vertex)

    def _update_input_indices(self):
        """Refresh input indices/names for every layer in the graph."""
        for layer in list(self._layers_graph):
            self._update_layer_input_indices(layer)

    @staticmethod
    def _update_layer_input_indices(layer):
        """Record each vertex-ordered predecessor of ``layer`` not yet in its input indices.

        The predecessor's index and name are appended to ``layer``. When ``layer`` already has
        a name, that name is also appended to the predecessor's outputs.
        """
        for prev_layer in layer.inputs_by_vertex_order:
            if prev_layer.index not in layer.input_indices:
                layer.input_indices.append(prev_layer.index)
                layer.inputs.append(prev_layer.name)
                if layer.name:
                    prev_layer.outputs.append(layer.name)

    def _update_output_indices(self):
        """Order splitter outputs by the framework's successor io index.

        Splitter layers get their outputs and output indices sorted this way, plus split
        indices when there are more outputs than split sizes. Every other layer's
        ``output_indices`` is reset to ``[]``.
        """
        for vertex, layer in self._vertices_to_layers.items():
            # Added support for group convolution layer, that has an external padding vertex - this makes sure we only
            # update output indices on the feature splitter itself.
            if layer.op in [LayerType.feature_splitter, LayerType.spatial_splitter, LayerType.width_splitter]:
                sorted_outputs = []
                sorted_indices = []
                successors_io_indices = vertex.get_vertex_successors_io_indices()
                for key in sorted(successors_io_indices.keys()):
                    successors = successors_io_indices[key]
                    for succ in successors:
                        if succ in self._vertices_to_layers:
                            succ_layer = self._vertices_to_layers[succ]
                            sorted_outputs.append(succ_layer.name)
                            sorted_indices.append(succ_layer.index)
                layer.outputs = sorted_outputs
                layer.output_indices = sorted_indices
                if len(layer.outputs) > len(layer.split_sizes):
                    self._update_split_indices(layer)
            else:
                layer.output_indices = []

    def _update_split_indices(self, layer):
        """Append, per output of splitter ``layer``, which split io index feeds that output.

        Successor vertices that were not converted are skipped. This mirrors
        :meth:`_update_output_indices`, which never puts them in ``layer.outputs``, so they
        cannot be looked up and would otherwise raise ``KeyError``.
        """
        split_vertex = self._graph.get_vertex_by_name(layer.original_names[0])
        vertices_io_indices = split_vertex.get_vertex_successors_io_indices()
        layers_io_indices = {}
        for io_index, vertices in vertices_io_indices.items():
            for vertex in vertices:
                if vertex not in self._vertices_to_layers:
                    continue
                # vertex can hold more than one split index
                succ_name = self._vertices_to_layers[vertex].name
                layers_io_indices.setdefault(succ_name, []).append(io_index)

        # A successor listed several times in layer.outputs consumes its io indices in order.
        for output in layer.outputs:
            layer.split_indices.append(layers_io_indices[output].pop(0))

    def _predecessors(self, vertex):
        """Yield ``(prev_vertex, prev_layer)`` for converted predecessors that map to a different layer."""
        layer = self._vertices_to_layers[vertex]
        for prev_vertex in self._graph.predecessors(vertex):
            if prev_vertex in self._vertices_to_layers:
                prev_layer = self._vertices_to_layers[prev_vertex]
                if prev_layer != layer:
                    yield prev_vertex, prev_layer

    def _calculate_valid_subgraph_scope(self):
        """Set ``in_valid_subgraph = True`` on every vertex that is in the conversion scope.

        Without end nodes, every vertex is in scope. Otherwise the walk goes backwards from
        each end node through predecessors. At a start node the walk stops, but that node's
        direct predecessors are still marked.

        NOTE(review): this assumes vertices default to ``in_valid_subgraph = False``, which is
        not set here. Because predecessors of start nodes are marked without being queued,
        the traversal order can affect the result.

        Raises:
            MisspellNodeError: If any end node name is not found in the graph.
        """
        if self._end_node_names is None:
            for vertex in self._graph.vertices_by_name.values():
                vertex.in_valid_subgraph = True
        else:
            valid_names = [valid_orig_name(x) for x in self._end_node_names]
            end_nodes = [self._graph.get_vertex_by_valid_name(x) for x in valid_names]
            wrong_indices = [i for i, node in enumerate(end_nodes) if node is None]
            if wrong_indices:
                wrong_names = [self._end_node_names[i] for i in wrong_indices]
                err_str = "end node names" if len(wrong_names) > 1 else "end node name"
                raise MisspellNodeError(f"Unable to find {err_str}: {wrong_names}, please verify and try again.")
            for end_node in end_nodes:
                preds_queue = [end_node]
                while preds_queue:
                    current_vertex = preds_queue.pop()
                    current_vertex.in_valid_subgraph = True
                    if self._start_node_names and current_vertex.name in self._start_node_names:
                        for node in self._graph.predecessors(current_vertex):
                            node.in_valid_subgraph = True
                    else:
                        for node in self._graph.predecessors(current_vertex):
                            if node not in preds_queue and not node.in_valid_subgraph:
                                preds_queue.append(node)

    @staticmethod
    def _contains_nan(values):
        """Return True if ``values`` is not None and contains at least one NaN.

        Unlike bare ``np.isnan``, this tolerates ``None``. Unlike a truthiness pre-check, it
        tolerates multi-element arrays.
        """
        return values is not None and bool(np.any(np.isnan(values)))

    def _validate_nan_values_in_params(self):
        """Reject layers with NaN parameters or negative batch-norm variance.

        All offending layers are collected first and reported in a single error.

        Raises:
            UnsupportedWeightsError: If at least one layer has an unsupported value.
        """
        has_nan = self._contains_nan
        raise_exceptions = ["The following layers have unsupported values:"]
        for layer in list(self._layers_graph):
            layers_nan_params = []
            layers_negative_params = []
            if hasattr(layer, "kernel") and not layer.dynamic_weights and has_nan(layer.kernel):
                layers_nan_params.append("kernel")
            if hasattr(layer, "bias") and has_nan(layer.bias):
                layers_nan_params.append("bias")
            if hasattr(layer, "bn_info"):
                bn_info = layer.bn_info
                if has_nan(bn_info.moving_mean):
                    layers_nan_params.append("moving_mean")
                # Small tolerance so float round-off around zero is not reported as negative.
                if bn_info.moving_variance is not None and np.any(bn_info.moving_variance < -1e-7):
                    layers_negative_params.append("moving_variance")
                if has_nan(bn_info.moving_variance):
                    layers_nan_params.append("moving_variance")
                if has_nan(bn_info.beta):
                    layers_nan_params.append("beta")
                if has_nan(bn_info.gamma):
                    layers_nan_params.append("gamma")
                if has_nan(bn_info.epsilon):
                    layers_nan_params.append("epsilon")
            if hasattr(layer, "leaky_alpha") and has_nan(layer.leaky_alpha):
                layers_nan_params.append("leaky_alpha")
            if hasattr(layer, "activation_threshold") and has_nan(layer.activation_threshold):
                layers_nan_params.append("activation_threshold")
            if hasattr(layer, "activation_delta_bias") and has_nan(layer.activation_delta_bias):
                layers_nan_params.append("activation_delta_bias")
            if hasattr(layer, "prelu_slope") and has_nan(layer.prelu_slope):
                layers_nan_params.append("prelu_slope")
            if hasattr(layer, "mean") and has_nan(layer.mean):
                layers_nan_params.append("mean")
            if hasattr(layer, "std") and has_nan(layer.std):
                layers_nan_params.append("std")

            if layers_nan_params:
                raise_exceptions.append(
                    f"Layer {layer.name} has unsupported NaN values in its "
                    f"{', '.join(layers_nan_params)}. (translated from {layer.original_names})",
                )

            if layers_negative_params:
                raise_exceptions.append(
                    f"Layer {layer.name} has unsupported negative values in its "
                    f"{', '.join(layers_negative_params)}. (translated from "
                    f"{layer.original_names})",
                )

        if len(raise_exceptions) > 1:
            raise UnsupportedWeightsError("\n".join(raise_exceptions))
