#!/usr/bin/env python
import copy
import os
import subprocess

import networkx as nx
import numpy as np
import tensorflow as tf

from hailo_sdk_client.model_translator.exceptions import (
    CantFindPReluSlopeError,
    UnexpectedNodeError,
    UnsupportedModelError,
)
from hailo_sdk_client.model_translator.graph_lookup import BwdChainNode, look_for_node
from hailo_sdk_client.model_translator.tf_translator.tf_graph import TFGraph, TFGraphNode
from hailo_sdk_common.hailo_nn.hn_layers import BatchNormValues
from hailo_sdk_common.logger.logger import default_logger


class TF2GraphNode(TFGraphNode):
    """TF2 flavor of ``TFGraphNode`` that locates this node's trained variables.

    Variable values are read from ``self._graph.values``, a mapping whose keys are variable
    paths (e.g. ``model/conv/kernel``), optionally suffixed with ``/Read/ReadVariableOp``.
    """

    def __init__(self, node_proto, graph):
        super().__init__(node_proto, graph)

    def get_vertex_variable_keys(self):
        """Return the ``self._graph.values`` keys that belong to this node.

        Each key (with any ``/Read/ReadVariableOp`` suffix removed) is split into a vertex
        prefix (everything up to and including the last ``/``) and a variable name (the last
        path component). A key is kept when:

        * its prefix is a substring of ``self.name``,
        * no prefix contained in the match name is longer than its own (longest prefix wins),
        * its variable name is one consumed by this node's op (e.g. ``kernel`` for ``Conv2D``,
          ``bias`` for ``BiasAdd``). For ``FusedBatchNorm*`` any variable name is accepted.

        Returns:
            list: Matching original keys in ``values`` iteration order. For depthwise convs, a
            ``depthwise*`` variable found after another candidate replaces all earlier ones.
        """
        orig_restored = {key.replace("/Read/ReadVariableOp", ""): key for key in self._graph.values}

        # edge case for keras optimization that replaced DWConv with SeparableConv consisting of
        # depthwise + pointwise (conv1x1)
        name_for_match = self.name
        if "separable" in self.name and "depthwise" not in self.name:
            name_for_match = "/".join(self.name.split("/")[0:-1]) + "/pointwise"

        # (vertex prefix, variable name, original key). rpartition splits on the LAST "/" only;
        # str.replace(variable_name, "") would also strip earlier occurrences of the variable
        # name from the prefix (e.g. "kernel_conv/kernel" -> "_conv/").
        candidates = []
        for stripped_key, key in orig_restored.items():
            prefix, sep, variable_name = stripped_key.rpartition("/")
            candidates.append((prefix + sep, variable_name, key))

        def match_score(vertex_name):
            # Longer prefixes contained in name_for_match are more specific matches.
            return len(vertex_name) if vertex_name in name_for_match else 0

        # Highest score over all prefixes, computed once instead of rescanning every prefix
        # for every candidate.
        best_score = max((match_score(vertex_name) for vertex_name, _, _ in candidates), default=0)

        possible_keys = []
        for vertex_name, variable_name, key in candidates:
            if vertex_name not in self.name or match_score(vertex_name) < best_score:
                continue

            is_conv = self.op in ["Conv2D", "Conv2DBackpropInput"] and variable_name in [
                "weights",
                "kernel",
                "pointwise_kernel",
            ]
            is_dwconv = self.op in ["DepthwiseConv2dNative"] and variable_name in [
                "weights",
                "kernel",
                "depthwise_weights",
                "depthwise_kernel",
            ]
            is_bias = self.op in ["BiasAdd", "AddV2"] and variable_name in ["bias", "biases"]
            is_bn = self.op in ["FusedBatchNorm", "FusedBatchNormV3"]
            is_dense = self.op in ["MatMul"] and variable_name in ["weights", "kernel"]
            is_einsum = self.op in ["Einsum"] and variable_name in ["kernel"]
            is_prelu = self.op in ["Relu"] and variable_name in ["alpha"]

            if is_conv or is_dwconv or is_bias or is_bn or is_dense or is_einsum or is_prelu:
                # edge case: prefer only one variable name for depthwise kernel when ambiguous
                if is_dwconv and possible_keys and "depthwise" in variable_name:
                    possible_keys = [key]
                else:
                    possible_keys.append(key)

        return possible_keys

    def get_layer_var_data(self):
        """Return a deep copy of this node's weight array together with its shape.

        Uses the first key from ``get_vertex_variable_keys``. If there is none, it falls back
        to a ``Const`` node found by a backward ``look_for_node`` search. That search uses the
        name hint ``filter`` for depthwise convs and ``ReadVariableOp`` otherwise; the exact
        name-matching rule lives in ``look_for_node``.

        Raises:
            UnsupportedModelError: If neither a variable nor a constant could be found.
        """
        values = None
        possible_keys = self.get_vertex_variable_keys()
        if possible_keys:
            # Index the same mapping the keys were collected from.
            values = self._graph.values[possible_keys[0]]
        else:
            # try to fallback to constant
            possible_name = "filter" if self.op in ["DepthwiseConv2dNative"] else "ReadVariableOp"
            node = look_for_node(self._graph, self, [BwdChainNode(op="Const", name=possible_name)])
            if node is not None:
                values = tf.make_ndarray(node._info.attr["value"].tensor)

        if values is None:
            raise UnsupportedModelError(f"Couldn't find weight variable for layer {self.name}")

        return copy.deepcopy(values), values.shape

    def get_bn_info(self):
        """Collect the batch-norm statistics and epsilon for this node.

        Variables are classified by substring of their key: ``moving_mean``, ``moving_var``,
        ``gamma`` or ``beta``. If gamma is absent, it defaults to ones shaped like beta.

        Raises:
            UnexpectedNodeError: If moving mean, moving variance, beta or the ``epsilon``
                attribute is missing.
        """
        moving_mean = moving_var = gamma = beta = None
        for key in self.get_vertex_variable_keys():
            val = self._graph.values[key]
            if "moving_mean" in key:
                moving_mean = val
            elif "moving_var" in key:
                moving_var = val
            elif "gamma" in key:
                gamma = val
            elif "beta" in key:
                beta = val

        epsilon = self._info.attr["epsilon"].f if "epsilon" in self._info.attr else None
        if any(x is None for x in [moving_mean, moving_var, beta, epsilon]):
            raise UnexpectedNodeError(f"FusedBatchNorm layer {self.name} is missing one or more weight variables.")
        if gamma is None:
            default_logger().debug(f"Gamma value not found from node {self.name}. Assumed gamma = 1.0.")
            gamma = np.ones(shape=beta.shape, dtype=np.float32)

        return BatchNormValues(
            moving_mean=moving_mean,
            moving_variance=moving_var,
            gamma=gamma,
            beta=beta,
            epsilon=epsilon,
        )

    def get_dilations(self, is_dilations_s2b=False):
        """Return the base-class dilations, or ``[1, 1, 1, 1]`` when it reports none (falsy)."""
        dilations = super().get_dilations(is_dilations_s2b=is_dilations_s2b)
        return dilations if dilations else [1, 1, 1, 1]

    def get_prelu_slope(self):
        """Return this node's PReLU ``alpha`` variable as a flat float array.

        Raises:
            CantFindPReluSlopeError: If no matching variable key exists.
        """
        possible_keys = self.get_vertex_variable_keys()
        if possible_keys:
            return np.array(self._graph.values[possible_keys[0]], dtype=float).flatten()

        raise CantFindPReluSlopeError(self.name)


class TF2Graph(TFGraph):
    """Graph of ``TF2GraphNode`` vertices built from a TF2 graph proto.

    ``__init__`` skips ``TFGraph.__init__`` (it calls the next initializer in the MRO). It then
    adds one ``TF2GraphNode`` per proto node and one edge per data input that resolves to a
    known vertex.
    """

    def __init__(self, raw_graph_proto, values):
        # super(TFGraph, ...) deliberately skips TFGraph.__init__ so nodes are wrapped as
        # TF2GraphNode below.
        super(TFGraph, self).__init__(raw_graph_proto, values)
        # assumes raw_graph_proto is GraphDef-like, exposing its NodeDefs as `.node` -- TODO confirm.
        # (Previously this was an empty list, so the graph never received any vertices/edges.)
        tf_graph_def = raw_graph_proto.node
        self._is_nchw = False
        # Original placeholder name -> sanitized name. Consumers still list the original name
        # in their `input`, so without this map their edge from the input would be dropped.
        renamed_inputs = {}

        for node in tf_graph_def:
            # edge case for model input
            if node.op == "Placeholder" and not node.input:
                sanitized_name = node.name.replace(".", "_")
                if sanitized_name != node.name:
                    renamed_inputs[node.name] = sanitized_name
                    node.name = sanitized_name

            vertex = TF2GraphNode(node, self)
            self.add_node(vertex)
            self.add_vertex_by_name(vertex)
            if vertex.is_nchw():
                self._is_nchw = True

        for node in tf_graph_def:
            for input_node in node.input:
                # Drop the ":<output index>" suffix to get the producer's node name.
                input_node_src = input_node.split(":")[0]
                input_node_src = renamed_inputs.get(input_node_src, input_node_src)
                if input_node_src not in self._vertices_by_name:
                    continue
                self.add_edge(self._vertices_by_name[input_node_src], self._vertices_by_name[node.name])

    @property
    def variables_names_lookup(self):
        """Variable-name lookup mapping (the attribute is populated outside this class)."""
        return self._variables_names_lookup

    @property
    def variables_names_reverse_lookup(self):
        """Inverse of ``variables_names_lookup`` (values mapped back to keys)."""
        return {value: key for key, value in self.variables_names_lookup.items()}

    def visualize(self, filename_prefix):
        """Write the graph to ``<prefix>.dot`` and render it to ``<prefix>.svg`` with Graphviz.

        ``nx_agraph.write_dot`` needs pygraphviz. Rendering needs the ``dot`` executable; if it
        is not found, the SVG is skipped and a warning is logged.
        """
        other_graph = nx.DiGraph()
        other_graph.add_nodes_from(self.nodes)
        other_graph.add_edges_from(self.edges)

        dot_path = f"{filename_prefix}.dot"
        svg_path = f"{filename_prefix}.svg"
        nx.drawing.nx_agraph.write_dot(other_graph, dot_path)
        # Pass argv as a list (no shell) so quotes/metacharacters in the prefix are not
        # interpreted by a shell. Non-zero exit codes are ignored, as before.
        try:
            subprocess.run(["dot", "-Tsvg", dot_path, "-o", svg_path], check=False)
        except FileNotFoundError:
            default_logger().warning(f"Graphviz 'dot' executable not found; skipped rendering {svg_path}")
