from collections import deque

import onnx
import onnxsim
from onnx.tools.update_model_dims import update_inputs_outputs_dims

from hailo_sdk_common.logger.logger import default_logger


class ONNXModelMetadataExtractor:
    """
    Split an ONNX model around a set of start / end nodes and collect metadata about the original model.

    Given an ONNX model and two lists of start and end node names, ``extract()`` returns:
        preprocess_model: subgraph of the given ONNX model that starts at the original model inputs and ends at the
            inputs of the start nodes. None when there is nothing to extract (no pre-process inputs or outputs were
            found) or when the extraction failed.
        postprocess_model: subgraph of the given ONNX model whose inputs are the outputs of the end nodes and whose
            outputs are the outputs of the original model. None when there is nothing to extract or when the
            extraction failed.
        original_model_meta: dict with the following keys
            ir_version: the ONNX model intermediate representation version.
            opset_imports: list of [domain, version] pairs taken from the model opset imports.
            preprocess_io_map: maps tensor names feeding the start nodes (or model input names given as start
                nodes) to ``input<i>``, where ``i`` is the index of the matching name in the start node list.
            inverse_postprocess_io_map: maps tensor names produced by the end nodes (or model output names given
                as end nodes) to ``output<i>``, where ``i`` is the index of the matching name in the end node list.
            input_dtype: maps ``input<i>`` to the ONNX element type of the matching tensor.
            output_dtype: maps ``output<i>`` to the ONNX element type of the matching tensor.
            errors: list of error messages collected while processing the model.
            is_supported_model: False when the model contains an operator listed in UNSUPPORTED_OPS.

    When no start (end) node names are given, the real inputs (outputs) of the model are used instead.
    """

    UNSUPPORTED_OPS = ["Loop", "If"]

    def __init__(self, model, start_node_names, end_node_names):
        """
        Args:
            model: the ONNX model to process.
            start_node_names: list of node / model input names at which the pre-process graph ends (may be empty
                or None).
            end_node_names: list of node / model output names at which the post-process graph starts (may be empty
                or None).
        """
        self._logger = default_logger()
        self._model = model
        self._start_node_names = start_node_names
        self._end_node_names = end_node_names
        self._errors = []
        self._real_inputs = None
        self._real_input_names = None
        self._output_names = None
        self._preprocess_io_map = {}
        self._postprocess_io_map = {}
        self._input_dtype = {}
        self._output_dtype = {}
        self._preprocess_inputs = set()
        self._preprocess_outputs = set()
        self._postprocess_inputs = set()
        self._postprocess_outputs = set()
        self._preprocess_model = None
        self._postprocess_model = None
        self._extractor = None

    def _record_error(self, error_msg):
        """Store an error message in the metadata error list and emit it at debug level."""
        self._errors.append(error_msg)
        self._logger.debug(error_msg)

    @staticmethod
    def _get_real_onnx_inputs(model):
        """Return the graph inputs of ``model`` that are not also listed as initializers."""
        initializers = {node.name for node in model.graph.initializer}
        return [node for node in model.graph.input if node.name not in initializers]

    def _simplify_model(self):
        """Replace the model with its onnxsim-simplified version; on failure keep the model and record an error."""
        try:
            model_simp, check = onnxsim.simplify(self._model, skip_fuse_bn=True)
        except Exception as e:
            self._record_error(f"Unable to simplify the model: {e!s}")
            return

        if check:
            self._model = model_simp
        else:
            self._record_error("Failed to simplify model")

    def _generate_io_mapping(self):
        """
        Fill the pre/post-process io maps and compute the input/output tensor sets of both sub-models.

        Start / end names may refer either to model inputs / outputs or to graph nodes; both cases are handled.
        """
        # Start names that are model inputs map directly to themselves.
        for input_node in self._real_inputs:
            input_name = input_node.name
            if input_name in self._start_node_names:
                i = self._start_node_names.index(input_name)
                self._preprocess_io_map[input_name] = f"input{i}"

        # End names that are model outputs map directly to themselves.
        for output_node in self._model.graph.output:
            output_name = output_node.name
            if output_name in self._end_node_names:
                i = self._end_node_names.index(output_name)
                self._postprocess_io_map[output_name] = f"output{i}"

        for node in self._model.graph.node:
            if node.name in self._start_node_names:
                i = self._start_node_names.index(node.name)
                # Only the first input of a start node is treated as its data input.
                self._preprocess_io_map[node.input[0]] = f"input{i}"
                self._add_preprocess_inputs(node)

            if node.name in self._end_node_names:
                i = self._end_node_names.index(node.name)
                for output in node.output:
                    self._postprocess_io_map[output] = f"output{i}"
                self._add_postprocess_outputs(node)

        model_inputs = set(self._real_input_names)
        model_outputs = set(self._output_names)

        hailo_inputs = set(self._preprocess_io_map)
        hailo_outputs = set(self._postprocess_io_map)

        self._preprocess_inputs -= hailo_inputs
        self._postprocess_outputs -= hailo_outputs

        # All outputs that are not hailo outputs or postprocess outputs
        preprocess_outputs_from_graph = model_outputs - hailo_outputs - self._postprocess_outputs
        # Take all given start nodes that are not input nodes (which can't be outputs)
        preprocess_outputs_from_start_nodes = hailo_inputs - model_inputs
        self._preprocess_outputs = preprocess_outputs_from_graph | preprocess_outputs_from_start_nodes

        # All inputs that are not hailo inputs or preprocess inputs
        postprocess_inputs_from_graph = model_inputs - hailo_inputs - self._preprocess_inputs
        # Take all given end nodes that are not output nodes (which can't be inputs)
        postprocess_inputs_from_end_nodes = hailo_outputs - model_outputs
        self._postprocess_inputs = postprocess_inputs_from_graph | postprocess_inputs_from_end_nodes

    def _nodes_by_tensor(self, field):
        """
        Map each tensor name to the graph nodes (in graph order) that list it in ``field``.

        ``field`` is either "input" (tensor -> consumers) or "output" (tensor -> producers). Empty tensor names
        are skipped: ONNX uses them as placeholders for omitted optional inputs / outputs, so they do not
        describe a real edge between nodes.
        """
        mapping = {}
        for node in self._model.graph.node:
            for tensor_name in getattr(node, field):
                if tensor_name:
                    mapping.setdefault(tensor_name, []).append(node)
        return mapping

    def _add_preprocess_inputs(self, curr_output_node):
        """
        Find the input nodes of preprocess graph.
        For each given Hailo start node (which is the pre-process graph end node) we find possible paths to ONNX graph
        inputs, avoiding passing in the same node twice using BFS. When ONNX input is reachable from the Hailo
        start node we add it to pre-process input set.
        """
        producers = self._nodes_by_tensor("output")
        # Nodes are tracked by name, so every node is expanded at most once.
        visited = {curr_output_node.name}
        queue = deque([curr_output_node])

        while queue:
            prev_node = queue.popleft()

            for inp in prev_node.input:
                if not inp:
                    # Omitted optional input, nothing feeds it.
                    continue

                if inp in self._real_input_names:
                    self._preprocess_inputs.add(inp)

                for input_node in producers.get(inp, ()):
                    if input_node.name not in visited:
                        visited.add(input_node.name)
                        queue.append(input_node)

    def _add_postprocess_outputs(self, curr_input_node):
        """
        Find the output nodes of postprocess graph.
        For each given Hailo end node (which is the post-process graph start node) we find possible paths to ONNX graph
        outputs, avoiding passing in the same node twice using BFS. When ONNX output is reachable from the Hailo
        end node we add it to post-process output set.
        """
        consumers = self._nodes_by_tensor("input")
        # Nodes are tracked by name, so every node is expanded at most once.
        visited = {curr_input_node.name}
        queue = deque([curr_input_node])

        while queue:
            next_node = queue.popleft()

            for out in next_node.output:
                if not out:
                    # Omitted optional output, nothing consumes it.
                    continue

                if out in self._output_names:
                    self._postprocess_outputs.add(out)

                for output_node in consumers.get(out, ()):
                    if output_node.name not in visited:
                        visited.add(output_node.name)
                        queue.append(output_node)

    @staticmethod
    def _dynamic_leading_dims(value_infos):
        """Return {name: dims} for the given graph inputs/outputs with the leading dimension replaced by -1."""
        dims_by_name = {}
        for value_info in value_infos:
            dims = [dim.dim_value for dim in value_info.type.tensor_type.shape.dim]
            dims_by_name[value_info.name] = [-1, *dims[1:]]
        return dims_by_name

    def _set_dynamic_shapes(self):
        """Mark the leading dimension of every graph input and output as dynamic; record an error on failure."""
        try:
            input_dims = self._dynamic_leading_dims(self._model.graph.input)
            output_dims = self._dynamic_leading_dims(self._model.graph.output)
            self._model = update_inputs_outputs_dims(self._model, input_dims, output_dims)

        except Exception as e:
            self._record_error(f"Unable to set dynamic shapes to the model: {e!s}")

    def _initialize_extractor(self):
        """Create the ONNX sub-model extractor; on failure leave ``self._extractor`` as None and record an error."""
        try:
            self._extractor = onnx.utils.Extractor(self._model)
        except Exception as e:
            self._record_error(f"Unable to initialize ONNX model extractor: {e!s}")

    def _extract_model(self, inputs, outputs):
        """
        Extract the sub-model between the given input and output tensor names.

        Returns None when either collection is empty. The names are passed to the extractor as sorted lists so
        the order of the extracted model's inputs / outputs is deterministic and does not depend on set
        iteration order.
        """
        if not inputs or not outputs:
            return None
        return self._extractor.extract_model(sorted(inputs), sorted(outputs))

    def _get_preprocess_model(self):
        """Return the pre-process sub-model, or None when there is nothing to extract or the extraction failed."""
        try:
            return self._extract_model(self._preprocess_inputs, self._preprocess_outputs)
        except Exception as e:
            self._record_error(
                f"Unable to extract pre-process model from {self._preprocess_inputs} to "
                f"{self._preprocess_outputs}: {e!s}"
            )
            return None

    def _get_postprocess_model(self):
        """Return the post-process sub-model, or None when there is nothing to extract or the extraction failed."""
        try:
            return self._extract_model(self._postprocess_inputs, self._postprocess_outputs)
        except Exception as e:
            self._record_error(
                f"Unable to extract post-process model from {self._postprocess_inputs} to "
                f"{self._postprocess_outputs}: {e!s}"
            )
            return None

    def _set_input_dtype(self):
        """Record the element type of every tensor in the pre-process io map, keyed by its ``input<i>`` name."""
        # Work on a copy: extending self._real_inputs itself would leak the pre-process outputs into the list of
        # real model inputs.
        hailo_input_nodes = list(self._real_inputs)
        if self._preprocess_model:
            hailo_input_nodes.extend(self._preprocess_model.graph.output)

        for input_node in hailo_input_nodes:
            if input_node.name in self._preprocess_io_map:
                self._input_dtype[self._preprocess_io_map[input_node.name]] = input_node.type.tensor_type.elem_type

    def _set_output_dtype(self):
        """Record the element type of every tensor in the post-process io map, keyed by its ``output<i>`` name."""
        hailo_output_nodes = list(self._model.graph.output)
        if self._postprocess_model:
            hailo_output_nodes.extend(self._get_real_onnx_inputs(self._postprocess_model))

        for output_node in hailo_output_nodes:
            if output_node.name in self._postprocess_io_map:
                self._output_dtype[self._postprocess_io_map[output_node.name]] = output_node.type.tensor_type.elem_type

    def _is_supported_model(self):
        """Return False (and record one error per offending node) when the graph uses an unsupported operator."""
        is_supported = True

        for node in self._model.graph.node:
            if node.op_type in self.UNSUPPORTED_OPS:
                self._errors.append(f"Node {node.name} of type {node.op_type} is not supported")
                is_supported = False

        return is_supported

    def extract(self):
        """
        Run the full extraction flow.

        Returns:
            Tuple of (preprocess_model, postprocess_model, original_model_meta). Both models are None when the model
            is not supported; the metadata dict is always returned and carries the collected errors.
        """
        is_supported_model = self._is_supported_model()

        if is_supported_model:
            self._simplify_model()
            self._real_inputs = self._get_real_onnx_inputs(self._model)
            self._real_input_names = [x.name for x in self._real_inputs]
            # Without explicit start / end nodes the whole model is considered.
            if not self._start_node_names:
                self._start_node_names = self._real_input_names
            self._output_names = [x.name for x in self._model.graph.output]
            if not self._end_node_names:
                self._end_node_names = self._output_names
            self._set_dynamic_shapes()
            self._generate_io_mapping()
            self._initialize_extractor()
            if self._extractor:
                self._preprocess_model = self._get_preprocess_model()
                self._postprocess_model = self._get_postprocess_model()
            self._set_input_dtype()
            self._set_output_dtype()

        original_model_meta = {
            "ir_version": self._model.ir_version,
            "opset_imports": [[opset_id.domain, opset_id.version] for opset_id in self._model.opset_import],
            "preprocess_io_map": self._preprocess_io_map,
            "inverse_postprocess_io_map": self._postprocess_io_map,
            "input_dtype": self._input_dtype,
            "output_dtype": self._output_dtype,
            "errors": self._errors,
            "is_supported_model": is_supported_model,
        }

        return self._preprocess_model, self._postprocess_model, original_model_meta