#!/usr/bin/env python
import difflib
import os
from contextlib import suppress

import numpy as np
import tflite
from google.protobuf.message import DecodeError

from hailo_sdk_client.exposed_definitions import NNFramework
from hailo_sdk_client.runner.exceptions import InvalidParserInputException
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError

# File extension of TensorFlow Lite flatbuffer models.
# NOTE(review): despite the plural name, this is a single extension string, not a collection.
TF_OPTIONAL_EXTENSIONS = ".tflite"

# Maps a ``tflite.TensorType`` enum value to the numpy dtype used to reinterpret the raw bytes
# stored in a tensor's flatbuffer buffer. Tensor types that are not listed here cannot be decoded
# into numpy values.
SUPPORTED_TFLITE_DTYPES = {
    tflite.TensorType.UINT8: np.uint8,
    tflite.TensorType.INT8: np.int8,
    tflite.TensorType.INT16: np.int16,
    tflite.TensorType.INT32: np.int32,
    tflite.TensorType.INT64: np.int64,
    tflite.TensorType.FLOAT16: np.float16,
    tflite.TensorType.FLOAT32: np.float32,
    tflite.TensorType.FLOAT64: np.float64,
}


def detect_tf_nn_framework(model_path):
    """Detect the framework of a TensorFlow model file and load its graph metadata.

    Only TensorFlow Lite flatbuffer models are supported; the file extension must be
    ``TF_OPTIONAL_EXTENSIONS`` (compared case-insensitively).

    Args:
        model_path: Path to the model file.

    Returns:
        A tuple ``(nn_framework, graph, values, node_names)``:
            nn_framework: always ``NNFramework.TENSORFLOW_LITE``.
            graph: the parsed ``tflite.Model``.
            values: dict mapping the name of every tensor backed by a non-empty buffer to its
                decoded numpy array.
            node_names: the subgraph input names followed by, for each operator, the name of
                its first output tensor (see ``get_tflite_model_meta_data``).

    Raises:
        UnsupportedModelError: If the file does not have a TFLite extension, could not be
            decoded, or holds a constant tensor whose dtype is not in ``SUPPORTED_TFLITE_DTYPES``.
    """
    unsupported_format_msg = (
        "Failed to analyze TF model, it appears the model provided is in an unsupported format. "
        "If you are using a TF1.x model (such as .ckpt or .pb), or a TF2.x model "
        "(such as .h5 or saved_model.pb), please refer to the user guide for details on how to "
        "convert to TensorFlow Lite format."
    )

    if os.path.splitext(model_path)[-1].lower() != TF_OPTIONAL_EXTENSIONS:
        raise UnsupportedModelError(unsupported_format_msg)

    try:
        graph, tensors, node_names = get_tflite_model_meta_data(model_path)
        values = _read_tflite_tensor_values(graph, tensors)
    except DecodeError as err:
        # NOTE(review): DecodeError is a protobuf exception; flatbuffer parsing of a truncated
        # .tflite file presumably surfaces as other exceptions (e.g. struct.error) — confirm
        # whether those should also be reported as an unsupported format.
        raise UnsupportedModelError(unsupported_format_msg) from err

    return NNFramework.TENSORFLOW_LITE, graph, values, node_names


def _read_tflite_tensor_values(graph, tensors):
    """Decode every tensor that is backed by a non-empty buffer into a numpy array.

    Args:
        graph: The parsed ``tflite.Model``; its buffers hold the raw tensor bytes.
        tensors: Tensors of the subgraph, as returned by ``get_tflite_model_meta_data``.

    Returns:
        A dict mapping tensor name to its decoded numpy array, reshaped to the stored shape.

    Raises:
        UnsupportedModelError: If a data-carrying tensor has a dtype missing from
            ``SUPPORTED_TFLITE_DTYPES``.
    """
    values = {}
    for tensor in tensors:
        buffer = graph.Buffers(tensor.Buffer())
        if buffer.DataLength() <= 0:
            # No stored data (presumably a non-constant tensor): nothing to decode.
            continue

        name = tensor.Name().decode("utf-8")
        tensor_type = tensor.Type()
        if tensor_type not in SUPPORTED_TFLITE_DTYPES:
            raise UnsupportedModelError(
                f"Tensor {name} has an unsupported TFLite data type ({tensor_type}).",
            )

        value = np.frombuffer(buffer.DataAsNumpy(), SUPPORTED_TFLITE_DTYPES[tensor_type])
        shape = tensor.ShapeAsNumpy()
        # ShapeAsNumpy() yields the int 0 instead of an array when no shape vector is stored;
        # in that case the flat array is kept as-is.
        if isinstance(shape, int) and shape == 0:
            values[name] = value
        else:
            values[name] = np.reshape(value, shape)
    return values


def suggest_other_node_names(name, node_names, prefix):
    """Raise an error for a node name missing from the model, suggesting similar names.

    Suggestions are gathered, in priority order and without duplicates, from:
      * scope matches: node names that contain every ``/``-separated component of ``name``
        and are at most one component longer than ``name``;
      * up to two fuzzy matches from ``difflib.get_close_matches`` (cutoff 0.8);
      * node names that contain ``name`` as a substring.

    Args:
        name: The node name that was not found.
        node_names: All node names available in the model.
        prefix: Text placed before "node" at the start of the error message.

    Raises:
        InvalidParserInputException: Always; the message lists any suggestions found.
    """
    name_parts = name.split("/")
    max_parts = len(name_parts) + 1

    scoped_matches = []
    for candidate in node_names:
        # Split each candidate once instead of once per component of `name`.
        candidate_parts = candidate.split("/")
        if len(candidate_parts) <= max_parts:
            candidate_part_set = set(candidate_parts)
            if all(part in candidate_part_set for part in name_parts):
                scoped_matches.append(candidate)

    diff_matches = difflib.get_close_matches(name, node_names, n=2, cutoff=0.8)
    included_matches = [candidate for candidate in node_names if name in candidate]

    # dict.fromkeys de-duplicates while preserving priority order, so the message is stable
    # across runs (a set's iteration order varies because str hashing is randomized).
    suggested_matches = list(dict.fromkeys(scoped_matches + diff_matches + included_matches))

    msg_str = f"{prefix} node {name} wasn't found in TF model."
    if suggested_matches:
        # Keep the "{'a', 'b'}" rendering the message has always used.
        formatted = ", ".join(repr(match) for match in suggested_matches)
        msg_str += f" Did you mean one of these? {{{formatted}}}"
    raise InvalidParserInputException(msg_str)


def get_tflite_model_meta_data(model_path):
    """Parse a TFLite flatbuffer file and collect the tensors and node names of its first subgraph.

    Only subgraph 0 is inspected; any additional subgraphs are ignored.

    Args:
        model_path: Path to the ``.tflite`` file.

    Returns:
        A tuple ``(graph, tensors, node_names)``:
            graph: the parsed ``tflite.Model``.
            tensors: list of all tensors of subgraph 0.
            node_names: names of the subgraph input tensors, followed by the name of the first
                output tensor of every operator, in operator order.

    Raises:
        UnsupportedModelError: If the model does not contain any subgraph.
    """
    with open(model_path, "rb") as f:
        buf = bytearray(f.read())

    graph = tflite.Model.GetRootAsModel(buf, 0)
    if graph.SubgraphsLength() == 0:
        # Subgraphs(0) would not yield a usable subgraph and the code below would fail with an
        # opaque AttributeError.
        raise UnsupportedModelError(f"TFLite model {model_path} does not contain any subgraph.")

    sub_graph = graph.Subgraphs(0)
    tensors = [sub_graph.Tensors(i) for i in range(sub_graph.TensorsLength())]
    ops = [sub_graph.Operators(i) for i in range(sub_graph.OperatorsLength())]

    # The generated *AsNumpy accessors return the int 0 instead of an array when the vector is
    # absent (the same convention detect_tf_nn_framework handles for ShapeAsNumpy).
    input_indices = sub_graph.InputsAsNumpy()
    if isinstance(input_indices, int):
        input_indices = []

    input_node_names = [tensors[idx].Name().decode("utf-8") for idx in input_indices]
    # NOTE(review): assumes every operator has at least one output tensor.
    ops_names = [tensors[op.OutputsAsNumpy()[0]].Name().decode("utf-8") for op in ops]
    node_names = input_node_names + ops_names

    return graph, tensors, node_names
