from collections import OrderedDict

import numpy as np

from hailo_sdk_client.exposed_definitions import InferenceContext
from hailo_sdk_client.runner.client_runner import ClientRunner
from hailo_sdk_client.tools.frameworks_inference.onnx_inference_helper import (
    reshape_by_output_format,
    run_onnx_runtime_inference,
)
from hailo_sdk_client.tools.frameworks_inference.tflite_inference_helper import tflite_infer
from hailo_sdk_client.tools.tf_proto_helper import (
    TF_OPTIONAL_EXTENSIONS,
    detect_tf_nn_framework,
)
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.tools.models_translator_helper import valid_orig_name

# Module-wide logger obtained from the Hailo SDK logging helpers.
logger = default_logger()

# Default fraction of near-zero relative-diff elements at/above which a compared output is treated as
# sparse (the comparison then relies on tensor energy instead of element-wise checks).
DEFAULT_SPARSITY_THRESHOLD = 0.9


class DFCInferenceToolExceptionError(Exception):
    """Raised by the DFC inference tool when input data or results cannot be matched to the model."""


def get_dataset_range(runner):
    """Derive random-input statistics from the first non-empty kernel of the model.

    Layers are scanned in the order the HN model iterates them; the first layer whose params contain a
    non-empty ``"kernel"`` decides the result.

    Args:
        runner: object exposing ``get_hn_model()`` (an iterable of layers with a ``name``) and
            ``get_params()`` (a mapping from layer name to a mapping of parameter arrays).

    Returns:
        A tuple ``(info, data_type)``. ``info`` is ``[mean, std, low, high]``: the midpoint of the kernel's
        value range, a quarter of that range, and the kernel's min and max; ``data_type`` is the kernel's
        dtype. When no usable kernel exists, ``([0.01, 0.25, -1.0, 1.0], np.float32)`` is returned.
    """
    default_info = [0.01, 0.25, -1.0, 1.0]
    default_dtype = np.float32
    hailo_nn = runner.get_hn_model()
    params = runner.get_params()
    for layer in hailo_nn:
        if layer.name not in params or "kernel" not in params[layer.name]:
            continue
        kernel = np.asarray(params[layer.name]["kernel"])
        if kernel.size == 0:
            # np.min/np.max raise on empty arrays; keep looking for a usable kernel instead of crashing.
            continue
        low = np.min(kernel)
        high = np.max(kernel)
        return [(high + low) / 2, (high - low) / 4, low, high], kernel.dtype

    return default_info, default_dtype


def get_random_dataset(network_size, mean=127.0, std=80.0, low=0.0, high=255.0, seed=42, dtype=np.float32):
    """Draw a reproducible, clipped normal sample shaped after ``network_size``.

    Dimensions that are ``None`` or not positive are replaced by 1. Values are drawn from
    ``N(mean, std)`` using a NumPy ``Generator`` seeded with ``seed``, clipped to ``[low, high]``
    and cast to ``dtype``.
    """
    sample_shape = [1 if dim is None or not dim > 0 else dim for dim in network_size]
    logger.debug(f"Randomize data with seed: {seed}")
    generator = np.random.default_rng(seed)
    samples = generator.normal(loc=mean, scale=std, size=sample_shape)
    clipped = np.clip(samples, low, high)
    return clipped.astype(dtype)


def is_sparse(array, allowed_tolerance=0.01, sparsity_threshold=0.9):
    """Tell whether at least ``sparsity_threshold`` of the elements of ``array`` are (near) zero.

    An element counts as zero-like when its magnitude is below ``allowed_tolerance``; comparing the
    magnitude keeps large negative values from being counted as zeros.

    Args:
        array: array-like of numbers (anything ``np.asarray`` accepts).
        allowed_tolerance: magnitude below which an element is considered zero.
        sparsity_threshold: minimal fraction of zero-like elements for the array to be sparse.

    Returns:
        ``True`` if the zero-like fraction is at least ``sparsity_threshold``; ``False`` otherwise,
        including for an empty array (which has no defined fraction).
    """
    values = np.asarray(array)
    if values.size == 0:
        return False
    zero_like_fraction = np.count_nonzero(np.abs(values) < allowed_tolerance) / values.size
    return bool(zero_like_fraction >= sparsity_threshold)


class DFCFrameworksInferenceTool:
    """Compare inference of an original framework model against its Hailo (DFC) translation.

    Two datasets are cached on the instance so that both inferences see the same inputs:

    * ``framework_dataset``: inputs for the original model, keyed by the original name of each input node.
    * ``hn_dataset``: the same inputs converted for the Hailo model, keyed by Hailo input layer name.

    ``runner`` may be a ``ClientRunner`` or a wrapper object exposing one through its ``.runner`` attribute.
    """

    def __init__(self, runner, model_path=None, net_name="model", start_node_names=None, end_node_names=None):
        self._runner = runner
        self._framework = self._detect_framework(model_path, net_name)
        self._model_path = model_path
        self._start_node_names = start_node_names
        self._end_node_names = end_node_names
        # Result keys create_comparison_dictionaries is allowed to compare; set by subclasses.
        self._comparable_layers_keys = []
        self._framework_dataset = {}
        self._hn_dataset = {}

    @property
    def runner(self):
        return self._runner

    @runner.setter
    def runner(self, runner):
        self._runner = runner

    @property
    def framework(self):
        return self._framework

    @property
    def model_path(self):
        return self._model_path

    @property
    def start_node_names(self):
        return self._start_node_names

    @start_node_names.setter
    def start_node_names(self, start_node_names):
        self._start_node_names = start_node_names

    @property
    def end_node_names(self):
        return self._end_node_names

    @end_node_names.setter
    def end_node_names(self, end_node_names):
        self._end_node_names = end_node_names

    @property
    def framework_dataset(self):
        return self._framework_dataset

    @framework_dataset.setter
    def framework_dataset(self, framework_dataset):
        self._framework_dataset = framework_dataset

    @property
    def hn_dataset(self):
        return self._hn_dataset

    @hn_dataset.setter
    def hn_dataset(self, hn_dataset):
        self._hn_dataset = hn_dataset

    @staticmethod
    def _detect_framework(model_path, net_name):
        """Return the framework of ``model_path``, or ``None`` when no path is given.

        Paths ending with one of ``TF_OPTIONAL_EXTENSIONS`` are resolved by ``detect_tf_nn_framework``;
        any other path is treated as ONNX. ``str()`` lets ``pathlib.Path`` objects be passed too.
        """
        if not model_path:
            return None
        if str(model_path).endswith(TF_OPTIONAL_EXTENSIONS):
            return detect_tf_nn_framework(model_path, net_name)[0]
        return "onnx"

    @staticmethod
    def _unwrap_runner(runner):
        """Return ``runner`` if it is a ``ClientRunner``, otherwise the ``ClientRunner`` it wraps."""
        return runner if isinstance(runner, ClientRunner) else runner.runner

    @staticmethod
    def _is_empty(data):
        """``True`` for ``None`` and empty containers/arrays, without ndarray truth-value ambiguity."""
        if data is None:
            return True
        if isinstance(data, np.ndarray):
            return data.size == 0
        return not data

    @staticmethod
    def _find_input_layer(input_layers, original_name):
        """Return the input layer whose ``original_names`` contain ``original_name``.

        Raises:
            DFCInferenceToolExceptionError: if no input layer is derived from ``original_name``.
        """
        layer = next((candidate for candidate in input_layers if original_name in candidate.original_names), None)
        if layer is None:
            raise DFCInferenceToolExceptionError(f"Could not find input layer derived from {original_name}.")
        return layer

    @staticmethod
    def _check_dataset_length(dataset, input_layers):
        """Raise if a dict dataset does not hold exactly one entry per input layer."""
        if len(dataset) != len(input_layers):
            raise DFCInferenceToolExceptionError("Length of dataset is different than the number of input layers.")

    @staticmethod
    def _check_single_input(input_layers):
        """Raise if a non-dict dataset is given for a model with several input layers."""
        if len(input_layers) > 1:
            raise DFCInferenceToolExceptionError(
                "For multiple input tensors, the dataset should be given as a "
                "dictionary, where each key is the original name of the input node.",
            )

    def generate_framework_dataset(self, framework_dataset=None, output_format=None):
        """Build, or return the cached, input dataset for the original framework model.

        Args:
            framework_dataset: optional user data. Either a dict mapping each input node's original name to its
                data or, for single-input models only, the data itself (or a list whose first item is the data).
                When omitted, random data is drawn in the value range reported by ``get_dataset_range``.
            output_format: unused; kept for backward compatibility. The framework dataset stays in the
                framework layout, and ``generate_hn_dataset`` converts it per input layer.

        Returns:
            The cached ``framework_dataset`` dict, keyed by the original name of each input node. Once it is
            non-empty, later calls return it unchanged and ignore their arguments.

        Raises:
            DFCInferenceToolExceptionError: if the given data does not match the model's input layers.
        """
        if not self._is_empty(self.framework_dataset):
            return self.framework_dataset

        runner = self._unwrap_runner(self.runner)
        info, data_type = get_dataset_range(runner)
        input_layers = runner.get_hn_model().get_input_layers()
        # Fill the existing dict in place so references obtained earlier from the property stay valid.
        dataset = self.framework_dataset if isinstance(self.framework_dataset, dict) else {}

        if not self._is_empty(framework_dataset):
            if isinstance(framework_dataset, dict):
                self._check_dataset_length(framework_dataset, input_layers)
                for original_name, data in framework_dataset.items():
                    layer = self._find_input_layer(input_layers, original_name)
                    dataset[layer.original_names[0]] = np.asarray(data).astype(data_type)
            else:
                self._check_single_input(input_layers)
                data = framework_dataset[0] if isinstance(framework_dataset, list) else framework_dataset
                # Same key (original name) and layout (unconverted) as the dict branch: generate_hn_dataset
                # looks input layers up by original name and performs the HN conversion itself.
                dataset[input_layers[0].original_names[0]] = np.asarray(data).astype(data_type)
        else:
            if self.framework == "onnx":
                # Shapes recorded for the ONNX start nodes (presumably in the original model's layout).
                input_shapes = runner.original_model_meta.get("start_nodes_shapes")
            else:
                input_shapes = {layer.original_names[0]: layer.input_shape for layer in input_layers}
            for input_name, input_shape in input_shapes.items():
                dataset[input_name] = get_random_dataset(
                    network_size=[1, *input_shape[1:]],  # a single sample; the batch dimension is forced to 1
                    mean=info[0],
                    std=info[1],
                    low=info[2],
                    high=info[3],
                    dtype=data_type,
                )

        self.framework_dataset = dataset
        return self.framework_dataset

    def convert_data_to_hn_format(self, runner, framework_layer_name, framework_data, output_format=None):
        """Convert one input tensor from the framework layout to the Hailo layout (ONNX only).

        Args:
            runner: the ``ClientRunner`` whose ``original_model_meta`` may hold a parsing report.
            framework_layer_name: original name of the input node; selects its entry in ``output_format``.
            framework_data: the tensor to convert.
            output_format: optional mapping of node name to format; defaults to the parsing report's.

        Returns:
            The reshaped tensor for ONNX models; ``framework_data`` unchanged for other frameworks or when
            no output format is available at all.
        """
        if self.framework != "onnx":
            return framework_data
        if not output_format:
            parsing_report = runner.original_model_meta.get("parsing_report")
            output_format = parsing_report.output_format if parsing_report else None
        if output_format is None:
            # No layout information; return the data as-is instead of dereferencing None.
            return framework_data
        return reshape_by_output_format(framework_data, output_format.get(framework_layer_name))

    def generate_hn_dataset(self, runner, framework_dataset=None, model_path=None, output_format=None):
        """Build, or return the cached, input dataset for the Hailo model.

        The framework dataset (the cached one, else ``framework_dataset``, else a generated one) is converted
        per input layer with ``convert_data_to_hn_format`` and cast to the dtype from ``get_dataset_range``.

        Args:
            runner: a ``ClientRunner`` or a wrapper exposing one as ``.runner``.
            framework_dataset: data used only when no framework dataset is cached yet.
            model_path: unused; kept for backward compatibility.
            output_format: forwarded to ``convert_data_to_hn_format``.

        Returns:
            Dict mapping Hailo input layer names to arrays; also stored as ``hn_dataset``.

        Raises:
            DFCInferenceToolExceptionError: if the framework dataset does not match the model's input layers.
        """
        if not self._is_empty(self.hn_dataset):
            return self.hn_dataset

        runner = self._unwrap_runner(runner)
        _, data_type = get_dataset_range(runner)
        input_layers = runner.get_hn_model().get_input_layers()
        hn_dataset = {}

        if self._is_empty(self.framework_dataset):
            self.framework_dataset = (
                framework_dataset if not self._is_empty(framework_dataset) else self.generate_framework_dataset()
            )
        source = self.framework_dataset

        if isinstance(source, dict):
            self._check_dataset_length(source, input_layers)
            for original_name, data in source.items():
                layer = self._find_input_layer(input_layers, original_name)
                converted = self.convert_data_to_hn_format(runner, original_name, data, output_format)
                hn_dataset[layer.name] = np.asarray(converted).astype(data_type)
        else:
            self._check_single_input(input_layers)
            data = source[0] if isinstance(source, list) else source
            converted = self.convert_data_to_hn_format(runner, input_layers[0].original_names[0], data, output_format)
            hn_dataset[input_layers[0].name] = np.asarray(converted).astype(data_type)

        self.hn_dataset = hn_dataset
        return hn_dataset

    def _run_onnx_runtime_inference_wrapper(
        self,
        hailo_nn,
        onnx_input_data,
        model_path,
        start_node_names,
        end_node_names,
        output_format=None,
    ):
        """Thin, overridable wrapper around ``run_onnx_runtime_inference``."""
        return run_onnx_runtime_inference(
            hailo_nn=hailo_nn,
            input_dataset=onnx_input_data,
            model_path=model_path,
            start_node_names=start_node_names,
            end_node_names=end_node_names,
            output_format=output_format,
        )

    def _get_onnx_end_node_names_from_hn(self, hailo_nn):
        """Return the last original name of every real output layer of ``hailo_nn``."""
        return [layer.original_names[-1] for layer in hailo_nn.get_real_output_layers()]

    def _get_onnx_results(self, onnx_dataset, **kwargs):
        """Run ONNX Runtime on ``onnx_dataset``, which is generated when ``None``.

        Returns:
            ``(hailo_nn, onnx_results, onnx_dataset, end_node_names)``.
        """
        # hasattr is False only when the property getter raises AttributeError (``_start_node_names`` never set).
        start_node_names = self.start_node_names if hasattr(self, "start_node_names") else self.runner.start_node_names
        runner = self._unwrap_runner(self.runner)
        if onnx_dataset is None:
            onnx_dataset = self.generate_framework_dataset(output_format=kwargs.get("output_format"))

        parsing_report = runner.original_model_meta.get("parsing_report")
        output_format = parsing_report.output_format if parsing_report else None
        hailo_nn = runner.get_hn_model()
        end_node_names = self._get_onnx_end_node_names_from_hn(hailo_nn)
        onnx_results = self._run_onnx_runtime_inference_wrapper(
            hailo_nn,
            onnx_dataset,
            self.model_path,
            start_node_names,
            end_node_names,
            output_format=output_format,
        )

        return hailo_nn, onnx_results, onnx_dataset, end_node_names

    def _get_tflite_results(self, dataset):
        """Run the TFLite model and return its outputs ordered like the HN model's output layers."""
        runner = self._unwrap_runner(self.runner)
        if dataset is None:
            dataset = self.generate_framework_dataset()
        hailo_nn = runner.get_hn_model()
        _, orig_names_res = tflite_infer(
            self.model_path, dataset, self.runner.start_node_names, self.runner.end_node_names
        )
        orig_names_res = {valid_orig_name(name): value for name, value in orig_names_res.items()}
        output_layers = (hailo_nn.get_layer_by_name(name) for name in hailo_nn.net_params.output_layers_order)
        return [orig_names_res[layer.original_names[-1]] for layer in output_layers]

    def get_original_results(self, framework, dataset=None, **kwargs):
        """Dispatch inference of the original model by ``framework`` ("onnx" or "tflite").

        Raises:
            ValueError: for any other framework.
        """
        if framework == "onnx":
            return self._get_onnx_results(
                onnx_dataset=dataset,
                **kwargs,
            )
        if framework == "tflite":
            return self._get_tflite_results(dataset=dataset)

        raise ValueError(f"Unsupported framework: {framework}")

    def get_dfc_inference_results(self, runner, infer_context=InferenceContext.SDK_NATIVE, hn_dataset=None):
        """Run ``runner.infer`` in ``infer_context``; a single ndarray result is wrapped in a list."""
        runner = self._unwrap_runner(runner)
        with runner.infer_context(infer_context) as ctx:
            native_res = runner.infer(ctx, hn_dataset)
            return [native_res] if isinstance(native_res, np.ndarray) else native_res

    def _get_results_for_comparison(self, framework=None, model_path="", **kwargs):
        """Return ``(native_results, original_results)``, computing whichever was not passed in ``kwargs``.

        ``kwargs`` may carry ``native_results``, ``original_results``, ``hn_dataset`` and ``output_format``.
        Original results are computed first, so the framework dataset exists before the HN dataset is
        derived from it and both inferences see the same (possibly random) inputs.
        """
        native_results = kwargs.get("native_results")
        original_results = kwargs.get("original_results")
        # Computed up front: it is needed for the native branch even when original results were supplied.
        original_model_path = model_path if model_path else self.model_path

        if self._is_empty(original_results):
            original_framework = framework if framework else self.framework
            original_results = self.get_original_results(
                original_framework,
                **kwargs,
            )
        if self._is_empty(native_results):
            # Generate the HN dataset only when none was supplied (a dict.get default would run eagerly).
            hn_dataset = kwargs.get("hn_dataset")
            if hn_dataset is None:
                hn_dataset = self.generate_hn_dataset(
                    self.runner, model_path=original_model_path, output_format=kwargs.get("output_format")
                )
            native_results = self.get_dfc_inference_results(self.runner, hn_dataset=hn_dataset)

        return native_results, original_results

    def create_comparison_dictionaries(
        self,
        framework=None,
        model_path="",
        allowed_tolerance=1e-2,
        **kwargs,
    ):
        """
        Args:
            framework: the framework of the original model
            model_path: the path to the original model
            allowed_tolerance: the allowed tolerance for the comparison
            **kwargs: forwarded to ``_get_results_for_comparison``; ``sparsity_threshold`` overrides
                ``DEFAULT_SPARSITY_THRESHOLD``.
        Returns:
            native_output_results: the results of the DFC inference, can be native, fp, quantized, etc.
            original_results: the results of the original model inference, or any expected results (quantized, etc.).
            result_diffs: per compared key, ``(is_same, diff_value, diff_tensor)``
            non_tested_layers: layers that were not tested (not comparable or shape mismatch)
        """
        results, expected_results = self._get_results_for_comparison(
            framework=framework, model_path=model_path, **kwargs
        )
        non_tested_layers = list(results)
        result_diffs = OrderedDict()
        sparsity_threshold = kwargs.get("sparsity_threshold", DEFAULT_SPARSITY_THRESHOLD)

        for key, expected_result in expected_results.items():
            if key not in results or key not in self._comparable_layers_keys:
                continue
            result = results[key]
            if not isinstance(result, np.ndarray):
                result = result.numpy()  # framework tensor objects, e.g. EagerTensor
            expected_result = np.asarray(expected_result)
            if expected_result.shape != result.shape:
                continue  # stays in non_tested_layers

            if result.size == 0:
                # Equal shapes, so both are empty; empty tensors may come from nms models, for example.
                result_diffs[key] = (True, 0, result - expected_result)
                non_tested_layers.remove(key)
                continue

            diff_tensor = result - expected_result
            if not isinstance(diff_tensor, np.ndarray):
                diff_tensor = np.array(diff_tensor)  # 0-d inputs yield numpy scalars, which cannot be assigned into
            # Differences within the tolerance count as exact matches.
            diff_tensor[np.abs(diff_tensor) < allowed_tolerance] = 0
            # Divide by |expected| so negative values close to -1e-10 cannot cancel the epsilon.
            relative_diff = np.abs(diff_tensor) / (np.abs(expected_result) + 1e-10)
            relative_diff_scalar = np.sum(np.abs(expected_result - result)) / (
                np.sum(np.abs(expected_result)) + 1e-10
            )
            # decision is based on the tensor energy when the tensor is sparse, elementwise otherwise
            # if the relative_diff tensor is all zeros, we can't rely on the tensor energy
            rely_on_tensor_energy = is_sparse(
                relative_diff,
                allowed_tolerance=allowed_tolerance,
                sparsity_threshold=sparsity_threshold,
            ) and not np.all(relative_diff == 0)
            if rely_on_tensor_energy:
                is_same = relative_diff_scalar <= allowed_tolerance
                diff_value = relative_diff_scalar
            else:
                is_same = np.all(relative_diff <= allowed_tolerance)
                diff_value = np.max(relative_diff)
            result_diffs[key] = (is_same, diff_value, diff_tensor)
            non_tested_layers.remove(key)

        return results, expected_results, result_diffs, non_tested_layers

    def dfc_validate_test_results(
        self,
        framework=None,
        model_path="",
        allowed_tolerance=1e-2,
        **kwargs,
    ):
        """
        General validation function for comparing the results of the original model inference to the results of the DFC.

        The verdict is only logged (info on success, warning otherwise); nothing is returned.
        """
        native_output_results, original_results, result_diffs, _ = self.create_comparison_dictionaries(
            framework,
            model_path,
            allowed_tolerance=allowed_tolerance,
            **kwargs,
        )
        # Name the framework that was actually used in the messages, not ``None``.
        framework = framework if framework else self.framework

        if len(native_output_results) != len(original_results):
            logger.warning(
                "Could not compare results due to different number of outputs. Hailo model had "
                f"{len(native_output_results)} outputs and {framework} model had {len(original_results)} outputs.",
            )
            return

        if not result_diffs:
            # Nothing was compared (e.g. every output shape mismatched); do not report a false success.
            logger.warning(
                "Model translation could not be verified: no outputs could be compared between Hailo native "
                f"model emulation and original {framework} runtime (e.g. due to mismatching output shapes).",
            )
            return

        if all(result_diff[0] for result_diff in result_diffs.values()):
            logger.info(
                "Model translation was verified successfully by comparing inference results "
                f"from Hailo native model emulation vs original {framework} runtime.",
            )
        else:
            logger.warning(
                f"Model translation failed verification, found difference when comparing inference "
                f"results and Hailo native model emulation vs original {framework} runtime. For common "
                f"reasons, please refer to the User Guide: 'Reasons and remedies for differences in the "
                f"parsed model'.",
            )


class HailoParserInferenceTool(DFCFrameworksInferenceTool):
    """Inference tool used after parsing; outputs are compared by position as ``out_0``, ``out_1``, ..."""

    def _get_onnx_results(self, onnx_dataset, **kwargs):
        """Run ONNX inference and reshape each result to the matching HN output layer's shape.

        The batch size is taken from ``shape[0]`` of the first input tensor.
        """
        hailo_nn, onnx_results, onnx_input_data, _ = super()._get_onnx_results(
            onnx_dataset=onnx_dataset,
            **kwargs,
        )
        first_input = (
            next(iter(onnx_input_data.values())) if isinstance(onnx_input_data, dict) else onnx_input_data[0]
        )
        batch = first_input.shape[0]
        # Unpacking works whether output_shape is a list or a tuple.
        output_shapes = [[batch, *layer.output_shape[1:]] for layer in hailo_nn.get_output_layers()]
        return [np.reshape(onnx_result, output_shape) for onnx_result, output_shape in zip(onnx_results, output_shapes)]

    def _get_results_for_comparison(self, framework=None, model_path="", **kwargs):
        """Key native and original results positionally so the base class can compare them.

        Each side is keyed by its own length: zipping the original results with the native keys would
        silently drop surplus original outputs and hide an output-count mismatch from the caller.
        """
        native_output_results, original_results = super()._get_results_for_comparison(
            framework=framework,
            model_path=model_path,
            **kwargs,
        )
        native_keys = [f"out_{i}" for i in range(len(native_output_results))]
        original_keys = [f"out_{i}" for i in range(len(original_results))]
        self._comparable_layers_keys = native_keys

        return dict(zip(native_keys, native_output_results)), dict(zip(original_keys, original_results))
