import os
import random
from functools import partial
from importlib.util import find_spec

import numpy as np
import tensorflow as tf
from PIL import Image

from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.logger.logger import default_logger


def random_data_generator(input_layers_names, shapes_dict, max_val, img_lst, dir_path, seed):
    """Yield an endless stream of ``(inputs, labels)`` calibration samples.

    Each yielded ``inputs`` dict maps every name in ``input_layers_names`` to an array of
    shape ``shapes_dict[name]``; ``labels`` is always an empty dict.

    When both ``img_lst`` and ``dir_path`` are truthy, each array comes from a randomly chosen
    file ``dir_path/img_lst[i]``: its pixel values are divided by 255 and multiplied by
    ``max_val``, then ``np.resize`` is applied (which repeats/truncates the flattened values to
    fit the shape; it does not interpolate). Otherwise values are drawn uniformly from
    ``[0, max_val)``.

    All randomness, including which image is picked, comes from a NumPy generator seeded
    with ``seed``, so the stream is reproducible for a given seed.

    Args:
        input_layers_names: Names of the inputs to produce, in order.
        shapes_dict: Mapping from input name to per-sample shape.
        max_val: Scale factor for the generated values.
        img_lst: Optional list of image file names relative to ``dir_path``.
        dir_path: Directory containing the files in ``img_lst``.
        seed: Seed for ``np.random.default_rng``.

    Yields:
        ``(dict[str, np.ndarray], dict)`` pairs.
    """
    rng = np.random.default_rng(seed)
    use_images = bool(img_lst and dir_path)
    while True:
        random_data = {}
        for input_layer_name in input_layers_names:
            input_shape = shapes_dict[input_layer_name]

            if use_images:
                # Use the seeded generator (not the global `random` module) so image choice is reproducible.
                idx = rng.integers(len(img_lst))
                with Image.open(os.path.join(dir_path, img_lst[idx])) as pil_img:
                    img = np.array(pil_img) / 255.0 * max_val
                img = np.resize(img, input_shape)
            else:
                img = rng.random(input_shape) * max_val
            random_data[input_layer_name] = img
        yield random_data, {}


def select_single_input(x, y, l_name):
    """Unwrap a single named input (and its label, if any) from dict-structured data.

    Args:
        x: Mapping of input name to value.
        y: Mapping of input name to label; may be empty.
        l_name: The input name to extract.

    Returns:
        ``(x[l_name], y[l_name])``, with a fresh empty dict in place of the label when
        ``y`` has no entry for ``l_name``.
    """
    features = x[l_name]
    labels = y.get(l_name, {})
    return features, labels


def get_random_calibset(
    input_layers_names, max_val, shapes_dict, img_lst, dir_path, signature_data, dataset_size=None, seed=1
):
    """Build a ``tf.data.Dataset`` streaming samples from ``random_data_generator``.

    Args:
        input_layers_names: Names of the model inputs.
        max_val: Scale factor forwarded to the generator.
        shapes_dict: Mapping from input name to per-sample shape.
        img_lst: Optional list of image file names forwarded to the generator.
        dir_path: Directory of the images in ``img_lst``.
        signature_data: Mapping from input name to ``{"shape": ..., "dtype": ...}``; turned
            into ``tf.TensorSpec`` objects here.
        dataset_size: If truthy, only the first ``dataset_size`` samples are kept; otherwise the
            dataset is infinite.
        seed: Seed forwarded to the generator.

    Returns:
        A dataset of ``(inputs, labels)`` pairs. When there is exactly one input, each element
        is ``(tensor, label)`` instead of ``(dict, label)``.
    """
    features_spec = {}
    for name, spec in signature_data.items():
        features_spec[name] = tf.TensorSpec(shape=spec["shape"], dtype=spec["dtype"])

    generator_fn = partial(random_data_generator, input_layers_names, shapes_dict, max_val, img_lst, dir_path, seed)
    dataset = tf.data.Dataset.from_generator(generator_fn, output_signature=(features_spec, {}))

    # A single-input model is fed the bare tensor rather than a one-entry dict.
    if len(input_layers_names) == 1:
        dataset = dataset.map(partial(select_single_input, l_name=input_layers_names[0]))

    # The generator never terminates; declare that explicitly.
    infinite = tf.data.experimental.INFINITE_CARDINALITY
    dataset = dataset.apply(tf.data.experimental.assert_cardinality(infinite))

    return dataset.take(dataset_size) if dataset_size else dataset


def generate_random_calib_set(
    model, max_val, as_callback, dataset_size=None, seed=1, img_lst=None, dir_path=None, shapes_dict=None
):
    """Create a calibration dataset (or a factory for one) for ``model``'s input layers.

    Args:
        model: Model exposing ``get_input_layers()``; each layer provides ``name``,
            ``input_shape`` and ``transposed``.
        max_val: Scale factor for the generated sample values.
        as_callback: If True, return a ``functools.partial`` that builds the dataset when
            called with no arguments; otherwise build and return the dataset now.
        dataset_size: Optional number of samples to keep (forwarded to ``get_random_calibset``).
        seed: Seed for the sample generator.
        img_lst: Optional list of image file names (relative to ``dir_path``) to sample from.
        dir_path: Directory containing the images in ``img_lst``.
        shapes_dict: Optional mapping from input layer name to per-sample shape. Layers not in
            it use ``layer.input_shape`` without its first (batch) entry, with the first two
            entries swapped for transposed layers whose shape has 3 entries. The caller's
            mapping is copied and never modified.

    Returns:
        The result of ``get_random_calibset``, or a partial that produces it.
    """
    # Copy so shapes inferred below are not written back into the caller's mapping.
    shapes_dict = dict(shapes_dict) if shapes_dict else {}
    signature_data = {}
    input_layers_names = []

    for layer in model.get_input_layers():
        input_layers_names.append(layer.name)
        if layer.name in shapes_dict:
            input_shape = shapes_dict[layer.name]
        else:
            input_shape = layer.input_shape[1:]
            if layer.transposed and len(input_shape) == 3:
                input_shape = [input_shape[1], input_shape[0], input_shape[2]]
            shapes_dict[layer.name] = input_shape

        # A plain, serializable dict instead of tf.TensorSpec; get_random_calibset rebuilds the spec.
        signature_data[layer.name] = {
            "shape": input_shape,
            "dtype": "float32",  # string form of the dtype keeps the dict serializable
        }

    calibset_args = (
        input_layers_names,
        max_val,
        shapes_dict,
        img_lst,
        dir_path,
        signature_data,
        dataset_size,
        seed,
    )
    if as_callback:
        return partial(get_random_calibset, *calibset_args)
    return get_random_calibset(*calibset_args)


def get_random_calib_dataset(hn_model, calib_random_max=1):
    """Build a calibration-dataset factory for ``hn_model`` from random data or tutorial images.

    For each input layer, the per-sample shape comes from ``hn_model.get_input_shapes``
    (first entry dropped). Conversion is honored only when the layer has a
    ``format_conversion`` successor that is in the emulation graph. For transposed layers
    whose shape has 3 entries, the first two entries are swapped.

    If every input layer's ``input_shape`` has 4 entries ending in 3 and no input has such a
    conversion, the ``.jpg`` files in ``<hailo_sdk_client package dir>/../hailo_tutorials/data/``
    are used instead of uniform random values, when that directory exists.

    Args:
        hn_model: The HN model whose input layers define the calibration inputs.
        calib_random_max: Scale factor for the generated sample values.

    Returns:
        The callback returned by ``generate_random_calib_set(..., as_callback=True)``.
    """
    input_layers = hn_model.get_input_layers()
    # Map input layer name -> per-sample shape; remember which inputs have a format conversion.
    input_layers_conversions = []
    names_to_shapes = {}
    for layer in input_layers:
        # Does the input layer have a format conversion successor in the emulation graph?
        conversion = any(
            succ.op == LayerType.format_conversion and succ.in_emulation_graph for succ in hn_model.successors(layer)
        )
        input_layers_conversions.append(conversion)  # used for RGB condition below
        in_shape = hn_model.get_input_shapes(ignore_conversion=(not conversion), specific_lname=layer.name)[0][1:]
        if layer.transposed and len(in_shape) == 3:
            in_shape = [in_shape[1], in_shape[0], in_shape[2]]
        names_to_shapes[layer.name] = in_shape

    img_lst = None
    rgb_dir_path = None
    all_rgb_inputs = all(len(layer.input_shape) == 4 and layer.input_shape[-1] == 3 for layer in input_layers)
    if all_rgb_inputs and not any(input_layers_conversions):
        sdk_client = find_spec("hailo_sdk_client")
        # find_spec returns None when the package is not installed, and a spec's origin may be
        # None; in either case fall back to random data instead of crashing.
        if sdk_client is not None and sdk_client.origin:
            rgb_dir_path = os.path.join(os.path.dirname(sdk_client.origin), "../hailo_tutorials/data/")
            # Check the same path that listdir will read, and require it to be a directory.
            if os.path.isdir(rgb_dir_path):
                # Sorted so image order (and thus seeded sampling) is independent of filesystem order.
                img_lst = sorted(
                    img_name for img_name in os.listdir(rgb_dir_path) if os.path.splitext(img_name)[1] == ".jpg"
                )
    if img_lst:
        default_logger().info(
            "Found model with 3 input channels, using real RGB images for calibration instead "
            "of sampling random data.",
        )
    return generate_random_calib_set(
        hn_model,
        calib_random_max,
        as_callback=True,
        img_lst=img_lst,
        dir_path=rgb_dir_path,
        shapes_dict=names_to_shapes,
    )
