import tensorflow as tf

from hailo_sdk_common.tools.models_translator_helper import valid_orig_name


def tflite_infer(model_path, feed_dict, input_tensors_names=None, output_tensors_names=None):
    """Run a TFLite model once and collect the requested input and output tensors.

    The interpreter is built with ``experimental_preserve_all_tensors=True`` so tensors other than
    the model's declared outputs can still be read back after ``invoke()``.

    Args:
        model_path: Path to the ``.tflite`` model, passed to ``tf.lite.Interpreter``.
        feed_dict: Either a dict of tensor name -> array, or an iterable of arrays that is paired
            positionally with ``interpreter.get_input_details()``.
        input_tensors_names: Optional names of the tensors to return as the network input. When
            falsy, the tensors fed through ``feed_dict`` are used.
        output_tensors_names: Optional names of the tensors to return as results. When falsy, the
            interpreter's own output tensors are used.

    Returns:
        ``(net_input, results_by_name)``: two dicts keyed by ``valid_orig_name(name)`` holding the
        requested input tensors and the requested output tensors. Requested output names that do
        not match any tensor are silently omitted.

    Raises:
        KeyError: If a ``feed_dict`` name matches no tensor in the model, or a requested input
            tensor cannot be found after inference.
    """
    interpreter = tf.lite.Interpreter(model_path, experimental_preserve_all_tensors=True)
    interpreter.allocate_tensors()
    # Read once and reused below: after invoke() only tensor names and indices are consulted,
    # and those are fixed by the model, so a second full scan is unnecessary.
    all_tensor_details = interpreter.get_tensor_details()

    if not isinstance(feed_dict, dict):
        # Positional feed: pair each array with the model inputs in declaration order.
        input_names = [x["name"] for x in interpreter.get_input_details()]
        feed_dict = dict(zip(input_names, feed_dict))

    input_details_by_name = {x["name"]: x for x in all_tensor_details if x["name"] in feed_dict}

    for name, dataset in feed_dict.items():
        details = input_details_by_name.get(name)
        if details is None:
            raise KeyError(f"feed_dict key {name!r} does not match any tensor name in the model")
        if len(details["shape"]) == 3 and len(dataset.shape) == 4:
            # The model expects a rank-3 tensor; drop axis 1 (it must have size 1) to match.
            dataset = dataset.squeeze(axis=1)
        interpreter.set_tensor(details["index"], dataset)
    interpreter.invoke()

    if not output_tensors_names:
        # Tests don't provide end node names; fall back to the model's declared outputs.
        output_tensors_names = [x["name"] for x in interpreter.get_output_details()]

    net_input_names = input_tensors_names if input_tensors_names else list(input_details_by_name)
    tensors_to_calculate = [*net_input_names, *output_tensors_names]

    # Look tensors up by their normalized name so caller-supplied names need not be pre-normalized.
    tensor_details_by_name = {valid_orig_name(x["name"]): x for x in all_tensor_details}
    results_by_name = {}
    for tensor_name in tensors_to_calculate:
        valid_tensor_name = valid_orig_name(tensor_name)
        details = tensor_details_by_name.get(valid_tensor_name)
        if details is not None:
            results_by_name[valid_tensor_name] = interpreter.get_tensor(details["index"])

    net_input = {}
    for valid_name in (valid_orig_name(x) for x in net_input_names):
        if valid_name not in results_by_name:
            raise KeyError(f"input tensor {valid_name!r} was not found in the model")
        net_input[valid_name] = results_by_name[valid_name]

    wanted_outputs = {valid_orig_name(x) for x in output_tensors_names}
    results_by_name = {x: y for x, y in results_by_name.items() if x in wanted_outputs}
    return net_input, results_by_name
