from functools import partial

import numpy as np


class HefInferContextException(Exception):
    """Raised when HEF inference is attempted outside of the HEF inference context.

    For example, ``HefInferWrapper.sync_tf_infer`` raises it when its
    ``infer_pipeline`` has not been set up yet.
    """


class HefInferWrapper:
    """Run inference on a HEF model through either the async or the sync HailoRT API.

    When ``infer_model`` is set, :meth:`tf_infer` uses the async path
    (``configured_infer_model.run_async``). Otherwise it falls back to the
    sync path, which requires ``infer_pipeline`` to be assigned externally.
    """

    # Timeout passed to ``wait`` on every async job (presumably milliseconds -
    # confirm against the HailoRT API).
    _ASYNC_JOB_TIMEOUT = 10000

    def __init__(
        self,
        infer_model,
        device,
        translate_input,
        rescale_output,
        output_names,
        network_groups=None,  # TODO: remove these 2 args when async api on hailo8 is enabled SDK-51150
        input_names=None,
        configured_infer_model=None,
    ):
        """Store the models, device and output configuration.

        Args:
            infer_model: Model used by the async path. ``None`` selects the sync path.
            device: Device handle, exposed through the ``device`` property.
            translate_input: Value exposed through the ``translate_input`` property.
            rescale_output: Value exposed through the ``rescale_output`` property.
            output_names: Names of the outputs returned by :meth:`tf_infer`, in order.
            network_groups: Stored for the sync path; not used by this class.
            input_names: Input names paired positionally with the arrays given to
                :meth:`sync_tf_infer`.
            configured_infer_model: Object used to create bindings and run async jobs.
        """
        # Sync pipeline, assigned from outside before sync inference.
        # TODO: remove when async api on hailo8 is enabled SDK-51150.
        self.infer_pipeline = None
        self._input_names = input_names
        self._network_groups = network_groups
        self._infer_model = infer_model
        self._configured_model = configured_infer_model
        self._device = device
        self._translate_input = translate_input
        self._rescale_output = rescale_output
        self._output_names = output_names
        # Per-frame output lists filled by binding_callback.
        self._results = []
        # Exceptions reported to binding_callback, re-raised by async_tf_infer.
        self._callback_errors = []

    @property
    def translate_input(self):
        return self._translate_input

    @property
    def rescale_output(self):
        return self._rescale_output

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, device):
        self._device = device

    @property
    def infer_model(self):
        return self._infer_model

    @infer_model.setter
    def infer_model(self, infer_model):
        self._infer_model = infer_model

    def binding_callback(self, completion_info, bindings, index=None):
        """Collect the outputs of one finished async job.

        Args:
            completion_info: Completion object; its ``exception`` attribute is
                checked first.
            bindings: Single-element list holding the job's bindings.
            index: Frame position to store the outputs at. When ``None``, the
                outputs are appended instead (the original behaviour).

        Raises:
            The exception carried by ``completion_info``, if any. It is recorded
            first so the thread that submitted the job can re-raise it.
        """
        if completion_info.exception:
            # Raising here only affects the thread running the callback, so
            # also record the error for async_tf_infer.
            self._callback_errors.append(completion_info.exception)
            raise completion_info.exception
        frame_outputs = [
            bindings[0].output(output_name).get_buffer(tf_format=self._infer_model.output(output_name).is_nms)
            for output_name in self._output_names
        ]
        if index is None:
            self._results.append(frame_outputs)
        else:
            # Store by position: jobs may finish in any order.
            self._results[index] = frame_outputs

    def sync_tf_infer(self, *args):
        """Run one blocking inference through ``infer_pipeline``.

        Args:
            *args: One array per input, paired positionally with ``input_names``.

        Returns:
            A list with one result per name in ``output_names``, in that order.

        Raises:
            HefInferContextException: If ``infer_pipeline`` has not been set.
        """
        # TODO: remove this method when async api on hailo8 is enabled SDK-51150.
        if self.infer_pipeline is None:
            raise HefInferContextException("Please use hailo_export.session inside the hef_infer_context")
        input_data = dict(zip(self._input_names, args))
        result = self.infer_pipeline.infer(input_data)
        return [result[output_layer] for output_layer in self._output_names]

    def tf_infer(self, *args):
        """Run inference on the async path if ``infer_model`` is set, else the sync path."""
        if self.infer_model is not None:
            return self.async_tf_infer(*args)
        return self.sync_tf_infer(*args)

    def async_tf_infer(self, *args):
        """Run one async job per frame of the batch and stack the outputs.

        Args:
            *args: Batched arrays whose first axis is the frame index. With a
                single array, it is fed to every model input (as before).
                Otherwise the arrays are paired positionally with
                ``infer_model.input_names``, like :meth:`sync_tf_infer`.

        Returns:
            A list with one ``float32`` array per name in ``output_names``.
            Frames are stacked along a new leading axis, in input order.

        Raises:
            ValueError: If no input is given, the batch is empty, or the paired
                inputs have different batch sizes.
            RuntimeError: If some frame produced no result.
        """
        if not args:
            raise ValueError("async_tf_infer expects at least one input array")

        input_names = list(self._infer_model.input_names)
        if len(args) == 1:
            per_input = {name: args[0] for name in input_names}
        else:
            per_input = dict(zip(input_names, args))

        batch_size = args[0].shape[0]
        if batch_size == 0:
            raise ValueError("async_tf_infer expects a non-empty batch")
        if any(array.shape[0] != batch_size for array in per_input.values()):
            raise ValueError("all input arrays must share the same batch size")

        # Clear state from the previous run. Pre-size so callbacks can write
        # by frame index.
        self._results = [None] * batch_size
        self._callback_errors = []

        # Keep every buffer referenced until all jobs finish, since the
        # bindings presumably read/write them asynchronously.
        live_buffers = []
        jobs = []
        for frame_idx in range(batch_size):
            frame_inputs = {name: array[frame_idx].astype(np.float32) for name, array in per_input.items()}
            frame_outputs = {
                name: np.empty(self._infer_model.output(name).shape, np.float32)
                for name in self._infer_model.output_names
            }
            live_buffers.append((frame_inputs, frame_outputs))
            binding = self._configured_model.create_bindings(frame_inputs, frame_outputs)
            self._configured_model.wait_for_async_ready()
            jobs.append(
                self._configured_model.run_async(
                    [binding],
                    partial(self.binding_callback, bindings=[binding], index=frame_idx),
                ),
            )

        # Wait for every job, not only the last one submitted.
        for job in jobs:
            job.wait(self._ASYNC_JOB_TIMEOUT)

        if self._callback_errors:
            raise self._callback_errors[0]
        if any(frame_outputs is None for frame_outputs in self._results):
            raise RuntimeError("async inference finished without results for every frame")

        return [
            np.stack([frame_outputs[i] for frame_outputs in self._results]).astype(np.float32)
            for i in range(len(self._output_names))
        ]
