import copy

import onnx
from onnx.shape_inference import infer_shapes

from hailo_sdk_common import get_version
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.logger.logger import default_logger


class UnsupportedHailoRuntimeException(Exception):
    """Raised when a Hailo ONNX runtime model cannot be generated from the given HEF and ONNX models."""


class HailoONNXModelComposer:
    """
    Build an ONNX model that wraps a compiled HEF so it can be inferred with onnxruntime.

    The resulting graph is laid out as: the pre-process model (if any) as a subgraph at the start, the HEF wrapped in
    a custom ``HailoOp`` node (domain ``ai.hailo``) in the middle, and the post-process model (if any) at the end.
    """

    PREPROCESS_PREFIX = "pre"
    POSTPROCESS_PREFIX = "post"
    HAILO_OP_NAME = "HailoOp"
    HAILO_OP_DOMAIN = "ai.hailo"
    HAILO_EXECUTION_PROVIDER_NAME = "HailoExecutionProvider"
    # ONNX element types that cannot be mapped to a HailoRT type; encountering one aborts composition.
    UNSUPPORTED_DTYPES = [
        onnx.TensorProto.STRING,
        onnx.TensorProto.BOOL,
        onnx.TensorProto.COMPLEX64,
        onnx.TensorProto.COMPLEX128,
    ]
    # Fallback ONNX element type per HN layer type, used when original_model_meta has no explicit dtype for a
    # tensor: FLOAT for every layer type except argmax, which defaults to INT64.
    LAYER_TO_DEFAULT_DTYPE = {
        **{layer_type: onnx.TensorProto.FLOAT for layer_type in LayerType},
        **{
            LayerType.argmax: onnx.TensorProto.INT64,
        },
    }

    def __init__(self, hailo_platform, hef_data, preprocess_model, postprocess_model, original_model_meta, hn):
        """
        Args:
            hailo_platform: HailoRT platform module; its ``HEF`` class and ``FormatOrder`` enum are used.
            hef_data: Serialized HEF, passed to ``hailo_platform.HEF`` and embedded in the HailoOp ``hef`` attribute.
            preprocess_model: ONNX model placed before the HailoOp, or a falsy value for none.
            postprocess_model: ONNX model placed after the HailoOp, or a falsy value for none.
            original_model_meta: Dict read for the keys ``input_dtype``, ``output_dtype``, ``preprocess_io_map``,
                ``inverse_postprocess_io_map``, ``opset_imports`` and ``ir_version``. Note that the two io maps
                are updated in place when Cast nodes are inserted.
            hn: HN model; ``hn.get_layer_by_name(name).op`` supplies the layer type of each HEF vstream.
        """
        self._logger = default_logger()
        self._hailo_platform = hailo_platform
        self._hef_data = hef_data
        self._preprocess_model = preprocess_model
        self._postprocess_model = postprocess_model
        self._original_model_meta = original_model_meta
        self._hn = hn

    def _get_shape_and_format_order(self, shape, layer_type=None):
        """
        Map a HEF vstream shape to an ONNX shape with a dynamic (-1) leading batch dim and a HailoRT format order.

        Args:
            shape: HEF vstream shape of rank 1, 2 or 3 (rank 3 is laid out as [H, W, C]).
            layer_type: Optional HN layer type. For argmax, a rank-2 shape gets an extra channel dim of 1; for
                average-pool layers, a rank-1 shape logs a warning.

        Returns:
            Tuple ``(onnx_shape, format_order)`` where ``format_order`` is ``int(FormatOrder.<order>)``.

        Raises:
            UnsupportedHailoRuntimeException: If ``shape`` has any other rank.
        """
        format_order = self._hailo_platform.FormatOrder
        rank = len(shape)
        if rank == 3:
            # [H, W, C] -> [-1, C, H, W]
            return (-1, shape[2], *shape[:2]), int(format_order.NCHW)
        if rank == 2:
            if layer_type == LayerType.argmax:
                # [H, W] -> [-1, 1, H, W]
                output_shape = (-1, 1, *shape)
            else:
                # [H, W] -> [-1, H, W]
                output_shape = (-1, *shape)
            return output_shape, int(format_order.NHW)
        if rank == 1:
            if layer_type in [LayerType.avgpool, LayerType.global_avg_pool]:
                self._logger.warning("GlobalAvgPool output shape might be different from ONNX shape")
            # [C] -> [-1, C]
            return (-1, *shape), int(format_order.NC)
        # No mapping exists for other ranks; fail explicitly instead of returning None and breaking the caller's
        # tuple unpacking with an unrelated TypeError.
        raise UnsupportedHailoRuntimeException(
            f"Unsupported tensor rank {rank} (shape {tuple(shape)}) for Hailo ONNX runtime model",
        )

    def _get_hailo_onnx_graph(self):
        """
        Build an ONNX graph whose inputs and outputs match the HEF vstreams, around a HailoOp node wrapping the HEF.

        When the dtype requested in ``original_model_meta`` differs from the dtype HailoRT exchanges, a Cast node is
        inserted on that input/output. The matching entry of ``preprocess_io_map`` /
        ``inverse_postprocess_io_map`` in ``original_model_meta`` is then redirected, in place, to the Cast tensor.

        Returns:
            Tuple ``(graph, inputs, outputs, input_cast_added)``:
            ``graph`` is the ONNX GraphProto; ``inputs``/``outputs`` are its graph input/output ValueInfoProtos;
            ``input_cast_added`` is True if at least one input Cast node was inserted.

        Raises:
            UnsupportedHailoRuntimeException: On an unsupported dtype or tensor rank.
        """

        def _get_hailort_dtype(dtype, tensor_name):
            """Return the ONNX dtype HailoRT exchanges for ``dtype``: UINT8/UINT16 unchanged, anything else FLOAT."""
            if dtype in self.UNSUPPORTED_DTYPES:
                raise UnsupportedHailoRuntimeException(
                    f"Unsupported data type {onnx.TensorProto.DataType.Name(dtype)} "
                    f"({dtype}) found in {tensor_name}",
                )

            if dtype in [onnx.TensorProto.UINT8, onnx.TensorProto.UINT16]:
                return dtype

            return onnx.TensorProto.FLOAT

        hef = self._hailo_platform.HEF(self._hef_data)

        # value_info entries for the HailoOp's own input/output tensors.
        values = []

        inputs = []
        input_names = []
        sorted_input_names = []
        input_quantized = []
        input_format_order = []
        input_cast_nodes = []
        input_cast_added = False
        for i, input_info in enumerate(hef.get_input_vstream_infos()):
            input_name = f"input{i}"
            layer_type = self._hn.get_layer_by_name(input_info.name).op
            input_type = self._original_model_meta["input_dtype"].get(
                input_name,
                self.LAYER_TO_DEFAULT_DTYPE[layer_type],
            )
            hailort_type = _get_hailort_dtype(input_type, input_name)
            input_shape, format_order = self._get_shape_and_format_order(input_info.shape)
            input_node = onnx.helper.make_tensor_value_info(input_name, hailort_type, input_shape)
            values.append(input_node)
            sorted_input_names.append(input_info.name)
            # Any non-FLOAT exchange type (UINT8/UINT16) is flagged to the HailoOp as quantized.
            input_quantized.append(hailort_type != onnx.TensorProto.FLOAT)
            input_format_order.append(format_order)
            input_names.append(input_name)

            if input_type != hailort_type:
                # Expose the requested dtype as the graph input and cast it to the HailoRT dtype before HailoOp.
                input_cast_added = True
                cast_node_name = f"Cast_{input_name}"
                cast_node_input_name = f"{cast_node_name}_input"
                input_cast_node = onnx.helper.make_node(
                    "Cast",
                    [cast_node_input_name],
                    [input_name],
                    to=hailort_type,
                    name=cast_node_name,
                )
                input_cast_nodes.append(input_cast_node)
                cast_input_tensor_value = onnx.helper.make_tensor_value_info(
                    cast_node_input_name,
                    input_type,
                    input_shape,
                )
                inputs.append(cast_input_tensor_value)
                # Pre-process outputs that fed this HailoOp input must now feed the Cast node instead.
                for k, v in self._original_model_meta["preprocess_io_map"].items():
                    if v == input_name:
                        self._original_model_meta["preprocess_io_map"][k] = cast_node_input_name
            else:
                inputs.append(input_node)

        outputs = []
        output_names = []
        sorted_output_names = []
        output_quantized = []
        output_format_order = []
        output_cast_nodes = []
        output_vstream_info_by_name = {output_info.name: output_info for output_info in hef.get_output_vstream_infos()}
        for i, output_name_from_hef in enumerate(hef.get_sorted_output_names()):
            layer_type = self._hn.get_layer_by_name(output_name_from_hef).op
            output_info = output_vstream_info_by_name[output_name_from_hef]
            output_name = f"output{i}"
            output_type = self._original_model_meta["output_dtype"].get(
                output_name,
                self.LAYER_TO_DEFAULT_DTYPE[layer_type],
            )
            hailort_type = _get_hailort_dtype(output_type, output_name)
            output_shape, format_order = self._get_shape_and_format_order(output_info.shape, layer_type)
            output_node = onnx.helper.make_tensor_value_info(output_name, hailort_type, output_shape)
            values.append(output_node)
            sorted_output_names.append(output_name_from_hef)
            output_quantized.append(hailort_type != onnx.TensorProto.FLOAT)
            output_format_order.append(format_order)
            output_names.append(output_name)

            if output_type != hailort_type:
                # HailoOp emits the HailoRT dtype; cast it to the requested dtype for the graph output.
                cast_node_name = f"Cast_{output_name}"
                cast_node_output_name = f"{cast_node_name}_output"
                output_cast_node = onnx.helper.make_node(
                    "Cast",
                    [output_name],
                    [cast_node_output_name],
                    to=output_type,
                    name=cast_node_name,
                )
                output_cast_nodes.append(output_cast_node)
                cast_output_tensor_value = onnx.helper.make_tensor_value_info(
                    cast_node_output_name,
                    output_type,
                    output_shape,
                )
                outputs.append(cast_output_tensor_value)
                # Post-process inputs that consumed this HailoOp output must now consume the Cast output instead.
                for k, v in self._original_model_meta["inverse_postprocess_io_map"].items():
                    if v == output_name:
                        self._original_model_meta["inverse_postprocess_io_map"][k] = cast_node_output_name
            else:
                outputs.append(output_node)

        hailo_node = onnx.helper.make_node(
            self.HAILO_OP_NAME,
            input_names,
            output_names,
            domain=self.HAILO_OP_DOMAIN,
            hef=self._hef_data,
            sorted_input_names=sorted_input_names,
            sorted_output_names=sorted_output_names,
            input_quantized=input_quantized,
            input_format_order=input_format_order,
            output_quantized=output_quantized,
            output_format_order=output_format_order,
            name=self.HAILO_OP_NAME,
        )

        hailo_graph = onnx.helper.make_graph(
            [*input_cast_nodes, hailo_node, *output_cast_nodes],
            "hailo-node",
            inputs,
            outputs,
            value_info=values,
        )

        return hailo_graph, inputs, outputs, input_cast_added

    def _compose_graphs(self, input_graph, output_graph, io_map, **kwargs):
        """
        Merge two ONNX graphs with ``onnx.compose.merge_graphs``, connecting ``input_graph`` outputs to
        ``output_graph`` inputs.

        Pairs in ``io_map`` whose names are not actual outputs of ``input_graph`` / inputs of ``output_graph`` are
        dropped. The merge is first attempted with ``kwargs`` (e.g. name prefixes). If that fails, it is retried
        without them.

        Returns:
            Tuple ``(graph, kwargs_applied)``; ``kwargs_applied`` is False when the fallback merge was used.

        Raises:
            UnsupportedHailoRuntimeException: If both merge attempts fail.
        """
        input_graph_outputs = [node.name for node in input_graph.output]
        output_graph_inputs = [node.name for node in output_graph.input]
        io_map = [(key, value) for key, value in io_map if key in input_graph_outputs and value in output_graph_inputs]

        try:
            graph = onnx.compose.merge_graphs(input_graph, output_graph, io_map=io_map, **kwargs)
            return graph, True
        except Exception as e:
            self._logger.debug(f"Failed to compose graphs with error: {e!s}, trying again without prefix")
            try:
                graph = onnx.compose.merge_graphs(input_graph, output_graph, io_map=io_map)
                return graph, False
            except Exception as retry_error:
                raise UnsupportedHailoRuntimeException(
                    f"Failed composing graphs in order to generate Hailo ONNX runtime model with error: "
                    f"{retry_error!s}",
                ) from retry_error

    def _update_shapes(self, model):
        """
        Return a copy of ``model`` with the leading dim of graph inputs and 4-D value_infos set to -1, then re-run
        ONNX shape inference.

        This is best-effort: if shape inference fails, a warning is logged and the unmodified ``model`` is returned.
        """
        updated_model = copy.deepcopy(model)
        initializer_names = {x.name for x in updated_model.graph.initializer}
        model_inputs = [
            input_node for input_node in updated_model.graph.input if input_node.name not in initializer_names
        ]
        for model_input in model_inputs:
            dims = model_input.type.tensor_type.shape.dim
            # Scalar or shapeless inputs have no leading dim to rewrite; indexing them would raise IndexError.
            if dims:
                dims[0].dim_value = -1

        for node_input_shape in updated_model.graph.value_info:
            if len(node_input_shape.type.tensor_type.shape.dim) == 4:
                node_input_shape.type.tensor_type.shape.dim[0].dim_value = -1
        try:
            return infer_shapes(updated_model)
        except Exception as e:
            self._logger.warning(f"Unable to update shapes to dynamic shapes: {e!s}")

        return model

    def compose(self):
        """
        Compose pre-process model -> HailoOp graph -> post-process model into a single checked ONNX model.

        Returns:
            The composed ``onnx.ModelProto``, importing the original opsets plus ``ai.hailo`` version 1.

        Raises:
            UnsupportedHailoRuntimeException: If graph construction, composition or model checking fails.
        """
        hailo_graph, hailo_inputs, hailo_outputs, input_cast_added = self._get_hailo_onnx_graph()

        # Read after _get_hailo_onnx_graph, which may have redirected entries to Cast node tensors.
        preprocess_io_map = self._original_model_meta["preprocess_io_map"]
        if self._preprocess_model:
            hailo_graph_with_preprocess, added_prefix = self._compose_graphs(
                self._preprocess_model.graph,
                hailo_graph,
                preprocess_io_map.items(),
                prefix1=f"{self.PREPROCESS_PREFIX}_",
            )
            if added_prefix:
                preprocess_io_map = {f"{self.PREPROCESS_PREFIX}_{k}": v for k, v in preprocess_io_map.items()}

            # merge_graphs rewires the HailoOp inputs to the (possibly prefixed) pre-process output names. Rename
            # those tensors back to the HailoOp input names, consistently for every producer and every consumer,
            # so that no reference is left dangling (e.g. HailoOp inputs beyond the first).
            # With input Cast nodes the Cast outputs already carry the HailoOp input names, so no rename is done.
            if not input_cast_added:
                for node in hailo_graph_with_preprocess.node:
                    for i, name in enumerate(list(node.input)):
                        if name in preprocess_io_map:
                            node.input[i] = preprocess_io_map[name]
                    for i, name in enumerate(list(node.output)):
                        if name in preprocess_io_map:
                            node.output[i] = preprocess_io_map[name]
                hailo_graph_with_preprocess.value_info.extend(hailo_inputs)
        else:
            hailo_graph_with_preprocess = hailo_graph

        if self._postprocess_model:
            # The inverse map is keyed by post-process input with HailoOp (or Cast) outputs as values; merge_graphs
            # needs (producer output, consumer input) pairs.
            postprocess_io_map = [
                (out, inp) for inp, out in self._original_model_meta["inverse_postprocess_io_map"].items()
            ]
            graph, _ = self._compose_graphs(
                hailo_graph_with_preprocess,
                self._postprocess_model.graph,
                postprocess_io_map,
                prefix2=f"{self.POSTPROCESS_PREFIX}_",
            )
            graph.value_info.extend(hailo_outputs)
        else:
            graph = hailo_graph_with_preprocess

        graph.doc_string = ""
        # Build a new list rather than appending in place, so the caller's metadata is left untouched and repeated
        # compose() calls do not accumulate duplicate ai.hailo opset entries.
        opset_import = [*self._original_model_meta["opset_imports"], [self.HAILO_OP_DOMAIN, 1]]
        meta = {
            "producer_name": "hailo",
            "producer_version": get_version("hailo_sdk_client"),
            "domain": self.HAILO_OP_DOMAIN,
            "doc_string": "ONNX model generated with HailoOp node",
            "ir_version": self._original_model_meta["ir_version"],
            "opset_imports": [onnx.helper.make_opsetid(*opset_id) for opset_id in opset_import],
        }

        try:
            model = onnx.helper.make_model(graph, **meta)
            model = self._update_shapes(model)
            onnx.checker.check_model(model)
        except Exception as e:
            raise UnsupportedHailoRuntimeException(
                f"Failed generating Hailo ONNX runtime model with error: {e!s}",
            ) from e

        return model
