import copy
import json
import os
import re
import time
from pathlib import Path, PosixPath
from typing import Optional

import numpy as np
import onnx
import onnxsim
from pydantic.v1 import BaseModel, Field

from hailo_model_optimization.acceleras.utils.params_loader import load_params
from hailo_model_optimization.tools.subprocess_wrapper import BaseSubprocessFlow, subprocess_wrapper
from hailo_sdk_client.exposed_definitions import DEFAULT_NN_FRAMEWORK, Dims, NNFramework
from hailo_sdk_client.model_translator.exceptions import MisspellNodeError, UnsupportedInputFormatError
from hailo_sdk_client.model_translator.fuser.fuser import HailoNNFuser
from hailo_sdk_client.model_translator.graph_lookup import (
    FwdChainNode,
    get_all_nodes_from_possible_chains,
)
from hailo_sdk_client.model_translator.onnx_translator.onnx_translator import ONNXConverter
from hailo_sdk_client.model_translator.parsing_report import ParsingReport
from hailo_sdk_client.model_translator.tflite_translator.tflite_translator import TFLiteConverter
from hailo_sdk_client.numeric_translator import bn_to_params
from hailo_sdk_client.runner.exceptions import InvalidParserInputException
from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import BackendRuntimeException
from hailo_sdk_client.tools.frameworks_inference.onnx_inference_helper import (
    run_shape_inference,
    set_model_net_input_shapes,
)
from hailo_sdk_client.tools.tf_proto_helper import detect_tf_nn_framework, suggest_other_node_names
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN
from hailo_sdk_common.hailo_nn.hn_definitions import HnStage, LayerType, LayerTypes
from hailo_sdk_common.logger.logger import DeprecatedAPI, DeprecationVersion, default_logger
from hailo_sdk_common.model_params.model_params import ModelParams
from hailo_sdk_common.onnx_tools.definitions import ONNX_LARGE_MODEL_BYTE_COUNT, ONNX_LARGE_MODEL_VERTEX_COUNT
from hailo_sdk_common.onnx_tools.onnx_model_metadata_extractor import ONNXModelMetadataExtractor
from hailo_sdk_common.tools.models_translator_helper import valid_orig_name


class ParserMemento(BaseModel):
    """
    Serializable snapshot of where the parser persisted its state.

    Produced by ``Parser.save_state`` and consumed by ``Parser.load_state``.
    """

    # Optional: a memento may be built without an HN file.
    hn_path: Optional[PosixPath] = Field(default=None, description="Path to the current Hn with the state.")
    # Required fields (``...`` marks them mandatory in pydantic).
    weights: PosixPath = Field(default=..., description="Path to the model weights.")
    results: PosixPath = Field(default=..., description="Path to the output results file.")


class Parser(BaseSubprocessFlow[ParserMemento]):
    """
    Translate TensorFlow-family and ONNX models into a Hailo HN plus native params.

    Results accumulate in ``return_data`` and can be persisted / restored through
    ``save_state`` / ``load_state`` using a ``ParserMemento``.
    """

    def __init__(self):
        self._logger = default_logger()
        # "original_model_meta" always exists; the translate_* methods fill it in.
        self._return_data = {"original_model_meta": {}}

    @property
    def return_data(self):
        """The dict of parsing results accumulated by this parser (returned by reference)."""
        return self._return_data

    @subprocess_wrapper()
    def translate_tf_model(
        self,
        model_path=None,
        net_name="model",
        start_node_names=None,
        end_node_names=None,
        tensor_shapes=None,
    ):
        """
        Translate a TensorFlow-family model file (framework detected by ``detect_tf_nn_framework``) into a HN.

        Args:
            model_path: Path to the model file.
            net_name: Network name; validated/normalized with ``HailoNN.get_valid_input_identifier``.
            start_node_names: Optional node name (str) or names to start parsing from.
            end_node_names: Optional node name (str) or names to end parsing at.
            tensor_shapes: Deprecated; must be ``None``.

        Returns:
            ``return_data`` as produced by ``parse_model_to_hn``.

        Raises:
            DeprecatedAPI: If ``tensor_shapes`` is given.
        """
        if tensor_shapes is not None:
            raise DeprecatedAPI("Tensor shapes flag is no longer supported for TF models.", DeprecationVersion.JUL2025)

        start_time = time.time()
        valid_net_name = HailoNN.get_valid_input_identifier(net_name, "net_name")
        self._logger.info(f"Translation started on Tensorflow model {valid_net_name}")

        nn_framework, graph, values, node_names = detect_tf_nn_framework(model_path)
        # A set gives O(1) membership checks for the user-provided node names below.
        valid_node_names = {valid_orig_name(name) for name in node_names}
        self._logger.debug(f"Restored {nn_framework.value.upper()} model {valid_net_name}")

        if isinstance(start_node_names, str):
            start_node_names = [start_node_names]
        if isinstance(end_node_names, str):
            end_node_names = [end_node_names]

        # Unknown names are reported through suggest_other_node_names (which offers close matches).
        for requested_names, role in ((start_node_names, "Start"), (end_node_names, "End")):
            for name in requested_names or []:
                if valid_orig_name(name) not in valid_node_names:
                    suggest_other_node_names(name, node_names, role)

        self._return_data["original_model_meta"]["original_model_path"] = os.path.abspath(model_path)
        parsing_results = self.parse_model_to_hn(
            graph,
            values,
            valid_net_name,
            start_node_names,
            end_node_names,
            nn_framework,
        )
        milestone = self._format_time_milestone(start_time)
        self._logger.info(
            f"Translation completed on Tensorflow model {valid_net_name} (completion time: {milestone})",
        )

        return parsing_results

    def _handle_input_shapes_or_format(self, onnx_model, net_input_arg, start_names, arg_name="shapes"):
        """
        Normalize ``net_input_shapes`` / ``net_input_format`` into a ``{input_name: value}`` dict.

        The expected input names are ``start_names`` when given, otherwise the graph inputs of
        ``onnx_model`` that are not initializers. A list/tuple is accepted only for single-input
        networks, either wrapped (``[value]``) or bare (``value``).

        For ``arg_name == "shapes"`` the result is passed to ``set_model_net_input_shapes`` together
        with ``onnx_model``; for ``arg_name == "format"`` every entry must be a ``Dims`` member.

        Raises:
            InvalidParserInputException: A list/tuple was given for a multi-input network, or a
                format entry is not a ``Dims``.
            UnsupportedModelError: The dict keys do not match the expected input names.
        """
        initializer_names = {x.name for x in onnx_model.graph.initializer}
        model_inputs = [x.name for x in onnx_model.graph.input if x.name not in initializer_names]
        net_input_names = start_names if start_names else model_inputs
        if isinstance(net_input_arg, (list, tuple)):
            if len(net_input_names) == 1:  # Single-input network: map the value to that input
                if len(net_input_arg) == 1 and isinstance(net_input_arg[0], (list, tuple)):
                    # Wrapped form, e.g. [[1, 3, 224, 224]]
                    net_input_arg = {net_input_names[0]: net_input_arg[0]}
                else:
                    # Bare form, e.g. [1, 3, 224, 224]
                    net_input_arg = {net_input_names[0]: net_input_arg}
            else:
                msg = f"net_input_{arg_name} must be a dictionary for multiple input networks"
                raise InvalidParserInputException(msg)

        if sorted(net_input_arg.keys()) != sorted(net_input_names):
            raise UnsupportedModelError(f"start_node_names and net_input_{arg_name} keys must contain the same names.")

        if arg_name == "shapes":
            set_model_net_input_shapes(onnx_model, net_input_arg)
        elif arg_name == "format":
            if any(not isinstance(dim, Dims) for dims_list in net_input_arg.values() for dim in dims_list):
                msg = f"Got invalid value for net_input_{arg_name}, please verify and try again."
                raise InvalidParserInputException(msg)

        return net_input_arg

    @staticmethod
    def _name_unnamed_nodes(onnx_model):
        """Give every unnamed node of ``onnx_model`` the name of its first output tensor (in place)."""
        for node in onnx_model.graph.node:
            if not node.name:
                node.name = node.output[0]

    def translate_onnx_model(
        self,
        model=None,
        net_name="model",
        start_node_names=None,
        end_node_names=None,
        net_input_shapes=None,
        augmented_path=None,
        disable_shape_inference=False,
        disable_rt_metadata_extraction=False,
        net_input_format=None,
        **kwargs,
    ):
        """
        Translate an ONNX model (file path or serialized ``bytes``) into a HN.

        If the first parse attempt fails, the model is simplified with ``onnxsim`` and parsed once
        more. The original error is re-raised without a retry for large/long models and for
        ``MisspellNodeError`` / ``UnsupportedInputFormatError``.

        Args:
            model: Path to an ``.onnx`` file, or the serialized model as ``bytes``.
            net_name: Network name; validated/normalized with ``HailoNN.get_valid_input_identifier``.
            start_node_names: Optional node name (str) or names to start parsing from.
            end_node_names: Optional node name (str) or names to end parsing at.
            net_input_shapes: Optional input shapes (see ``_handle_input_shapes_or_format``).
            augmented_path: If given (and the model is not large), the loaded model, and the
                simplified model on a retry (``.sim.onnx``), are saved there.
            disable_shape_inference: Skip ONNX shape inference (forced for large ``bytes`` models).
            disable_rt_metadata_extraction: Skip ONNXRuntime pre/post-process metadata extraction.
            net_input_format: Optional per-input ``Dims`` lists.
            **kwargs: Forwarded to ``parse_model_to_hn`` on every parse attempt
                (e.g. ``rename_layers_by_blocks``).

        Returns:
            ``return_data`` as produced by ``parse_model_to_hn``.
        """
        start_time = time.time()
        valid_net_name = HailoNN.get_valid_input_identifier(net_name, "net_name")
        self._logger.info(f"Translation started on ONNX model {valid_net_name}")

        model_loaded_from_bytes = isinstance(model, bytes)
        if not model_loaded_from_bytes:
            self._return_data["original_model_meta"]["original_model_path"] = os.path.abspath(model)

        onnx_model = (
            onnx.load_model_from_string(model) if model_loaded_from_bytes else onnx.load(model, load_external_data=True)
        )
        large_model_detected = onnx_model.ByteSize() > ONNX_LARGE_MODEL_BYTE_COUNT
        long_model_detected = len(onnx_model.graph.node) > ONNX_LARGE_MODEL_VERTEX_COUNT
        oversized_model = large_model_detected or long_model_detected

        should_skip_checker = False
        if oversized_model:
            self._logger.warning(
                "Large model detected. The graph may contain either a large number of operators, "
                "or weight variables with a very large capacity.",
            )
            self._logger.warning(
                "Translation time may be a bit long, and some features may be disabled (e.g. model "
                "augmentation, retry simplified model, onnx runtime hailo model extraction, etc.).",
            )

            if model_loaded_from_bytes:
                disable_shape_inference = True
                should_skip_checker = True
                self._logger.warning(
                    "Shape inference was disabled due to model size and type (bytes). "
                    "To use shape inference please provide the model as path.",
                )

        if not should_skip_checker:
            onnx.checker.check_model(model)

        milestone = self._format_time_milestone(start_time)
        self._logger.info(f"Restored ONNX model {valid_net_name} (completion time: {milestone})")

        self._name_unnamed_nodes(onnx_model)

        if isinstance(start_node_names, str):
            start_node_names = [start_node_names]

        if isinstance(end_node_names, str):
            end_node_names = [end_node_names]

        if disable_rt_metadata_extraction or oversized_model:
            self._logger.debug("ONNX metadata extractor is disabled")
            self._return_data["original_model_meta"]["extractor_disabled"] = True
        else:
            # The extractor gets a deep copy so the model used for parsing is not affected by it.
            metadata_extractor = ONNXModelMetadataExtractor(copy.deepcopy(onnx_model), start_node_names, end_node_names)
            preprocess_model, postprocess_model, original_model_meta = metadata_extractor.extract()
            self._return_data["preprocess_model"] = preprocess_model
            self._return_data["postprocess_model"] = postprocess_model
            self._return_data["original_model_meta"].update(original_model_meta)

            milestone = self._format_time_milestone(start_time)
            self._logger.info(f"Extracted ONNXRuntime meta-data for Hailo model (completion time: {milestone})")

        if net_input_shapes is not None:
            net_input_shapes = self._handle_input_shapes_or_format(onnx_model, net_input_shapes, start_node_names)

        if net_input_format is not None:
            net_input_format = self._handle_input_shapes_or_format(
                onnx_model, net_input_format, start_node_names, "format"
            )

        # Save the augmented model only when a path is given (skipped for large models).
        if augmented_path and not large_model_detected:
            self._logger.info(
                f"Saving a modified model, augmented with tensors names (where applicable). New file path"
                f" is at {augmented_path}",
            )
            onnx.save_model(onnx_model, augmented_path)

        try:
            parsing_results = self._parse_onnx_model_to_hn(
                onnx_model=onnx_model,
                net_name=valid_net_name,
                start_node_names=start_node_names,
                end_node_names=end_node_names,
                net_input_shapes=net_input_shapes,
                disable_shape_inference=disable_shape_inference,
                net_input_format=net_input_format,
                **kwargs,
            )

        except Exception as e:
            irrelevant_exception = isinstance(e, (MisspellNodeError, UnsupportedInputFormatError))
            if oversized_model or irrelevant_exception:
                raise e from None

            # Retry with a simplified model; on any simplification problem surface the original error.
            try:
                simplified_model, is_valid = onnxsim.simplify(onnx_model, skip_fuse_bn=True)
            except Exception as simplify_error:
                self._logger.info(f"Unable to simplify the model: {simplify_error!s}")
                raise e from None

            if not is_valid:
                self._logger.info("Failed to simplify model")
                raise e from None

            self._name_unnamed_nodes(simplified_model)

            # Save the simplified model if an augmented path was given
            if augmented_path and not large_model_detected:
                simplified_path = str(augmented_path).replace(".onnx", ".sim.onnx")
                self._logger.info(
                    f"Saving a simplified model, augmented with tensors names (where applicable). New "
                    f"file path is at {simplified_path}",
                )
                onnx.save_model(simplified_model, simplified_path)

            milestone = self._format_time_milestone(start_time)
            self._logger.info(f"Simplified ONNX model for a parsing retry attempt (completion time: {milestone})")

            parsing_results = self._parse_onnx_model_to_hn(
                onnx_model=simplified_model,
                net_name=valid_net_name,
                start_node_names=start_node_names,
                end_node_names=end_node_names,
                net_input_shapes=net_input_shapes,
                disable_shape_inference=disable_shape_inference,
                net_input_format=net_input_format,
                **kwargs,
            )

        milestone = self._format_time_milestone(start_time)
        self._logger.info(f"Translation completed on ONNX model {valid_net_name} (completion time: {milestone})")

        return parsing_results

    def _parse_onnx_model_to_hn(
        self,
        onnx_model,
        net_name,
        start_node_names,
        end_node_names,
        net_input_shapes,
        disable_shape_inference,
        net_input_format,
        **kwargs,
    ):
        """
        Collect tensor shapes for ``onnx_model`` and hand it to ``parse_model_to_hn``.

        Shapes start from the model's declared ``value_info`` and graph inputs (one shape list per
        tensor name). Unless disabled, they are refined by ``run_shape_inference``; an inference failure
        is logged as a warning and the declared shapes are used. ``**kwargs`` are forwarded to
        ``parse_model_to_hn``.
        """
        output_shapes = {
            value_info.name: [[x.dim_value for x in value_info.type.tensor_type.shape.dim]]
            for value_info in [*onnx_model.graph.value_info, *onnx_model.graph.input]
        }

        if disable_shape_inference:
            self._logger.debug("ONNX shape inference is disabled")
        else:
            try:
                output_shapes = run_shape_inference(onnx_model, end_node_names, net_input_shapes, output_shapes)
            except Exception as e:
                # Best effort: fall back to the declared shapes.
                self._logger.warning(f"ONNX shape inference failed: {e!s}")

        return self.parse_model_to_hn(
            onnx_model,
            None,
            net_name,
            start_node_names,
            end_node_names,
            nn_framework=NNFramework.ONNX,
            output_shapes=output_shapes,
            net_input_format=net_input_format,
            **kwargs,
        )

    def parse_model_to_hn(
        self,
        model,
        values,
        net_name,
        start_node_names=None,
        end_node_names=None,
        nn_framework=DEFAULT_NN_FRAMEWORK,
        output_shapes=None,
        net_input_format=None,
        rename_layers_by_blocks=False,
    ):
        """
        Convert ``model`` with the framework converter, fuse it into a HailoNN and store the results.

        Args:
            model: ONNX ``ModelProto`` for ONNX, or the graph returned by ``detect_tf_nn_framework``.
            values: Framework values passed to the converter (``None`` for ONNX).
            net_name: Network name used for fusing and HN export.
            start_node_names: Optional original node names to start parsing from.
            end_node_names: Optional original node names to end parsing at.
            nn_framework: ``NNFramework.TENSORFLOW_LITE`` or ``NNFramework.ONNX``.
            output_shapes: ONNX only - tensor shapes forwarded to the converter.
            net_input_format: ONNX only - per-input ``Dims`` forwarded to the converter.
            rename_layers_by_blocks: Rename layers by their block number after fusing.

        Returns:
            ``return_data`` updated with "hn_data", "bn_rescaled_params" and original model metadata.

        Raises:
            BackendRuntimeException: For an unsupported ``nn_framework``.
        """
        if nn_framework == NNFramework.TENSORFLOW_LITE:
            converter = TFLiteConverter(
                model=model,
                values=values,
                start_node_names=start_node_names,
                end_node_names=end_node_names,
            )
        elif nn_framework == NNFramework.ONNX:
            converter = ONNXConverter(
                model=model,
                values=values,
                output_shapes=output_shapes,
                start_node_names=start_node_names,
                end_node_names=end_node_names,
                net_input_format=net_input_format,
            )
        else:
            raise BackendRuntimeException(f"Unsupported NN framework {nn_framework}")

        fuser = HailoNNFuser(converter.convert_model(), net_name, converter.end_node_names)
        hailo_nn = fuser.convert_model()
        hailo_nn.validate_stage(HnStage.HN)
        hailo_nn.net_params.is_transformer = self.is_transformer(hailo_nn)
        # Maps each input layer's first original name to its HN layer name.
        input_layers_mapping = {x.original_names[0]: x.name for x in hailo_nn.get_input_layers()}
        input_layers = [f"'{orig_name}': '{hn_name}'" for orig_name, hn_name in input_layers_mapping.items()]
        output_layers = [f"'{x.original_names[-1]}'" for x in hailo_nn.get_real_output_layers()]
        self._logger.info(f"Start nodes mapped from original model: {', '.join(input_layers)}.")
        self._logger.info(f"End nodes mapped from original model: {', '.join(output_layers)}.")
        if rename_layers_by_blocks:
            self._rename_layers_by_blocks(hailo_nn)
        hn_data, params_data = hailo_nn.to_hn_npz(net_name)
        params_data = {k: v if type(v) is np.ndarray else np.array(v) for k, v in params_data.items()}
        # apply native params to calculate kernel + bias for standalone batch_norm
        bn_rescaled_params = bn_to_params.batch_norm_rescale_params(model=hailo_nn, params=ModelParams(params_data))

        self._return_data["hn_data"] = hn_data
        self._return_data["bn_rescaled_params"] = bn_rescaled_params
        self._return_data["original_model_meta"]["framework"] = str(nn_framework)
        self._return_data["original_model_meta"]["parsing_report"] = converter.get_parsing_report()
        self._return_data["original_model_meta"]["start_nodes_shapes"] = self.get_start_nodes_shapes(
            nn_framework, converter, list(input_layers_mapping.keys())
        )
        if hailo_nn.detected_anchors:
            self._return_data["original_model_meta"]["detected_anchors"] = hailo_nn.detected_anchors

        return self._return_data

    def _rename_layers_by_blocks(self, hailo_nn):
        """
        Rename the layers of ``hailo_nn`` (in place) by the block number they belong to.

        The block number is extracted from the original name of the layer. In LLaMA-style models the
        original names follow the pattern:
        /{model}/layers.{block_number}/.../{type}_{index}
        and the new name is created by the following pattern:
        {scope}/block{block_number}__{layer_type}{index}
        where:
        - {scope} is the scope of the layer
        - {block_number} is the block number of the layer
        - {layer_type} is the type of the layer
        - {index} is the index of the layer in the block

        Layers at the same position in different blocks share the same suffix. Blocks are processed in
        ascending numeric order so the generated names do not depend on set iteration order (string
        hashing is randomized per process). Names containing "layers." without a numeric block id are
        ignored.

        Returns:
            ``hailo_nn`` (modified in place).
        """
        block_number_pattern = re.compile(r"layers\.(\d+)")
        # creates mapping between the layer type and the next available index for this type
        type_to_available_index = {layer_type: 1 for layer_type in LayerTypes}
        # extracts the original names of all layers that are a part of some block
        original_names = [
            name for layer in hailo_nn.stable_toposort() for name in layer.original_names if "layers." in name
        ]

        if len(original_names) == 0:
            self._logger.info("No layers to rename by blocks were found")
            return hailo_nn

        # extracts the numbers of blocks from the layers' original names, in deterministic numeric order
        block_matches = (block_number_pattern.search(name) for name in original_names)
        blocks_numbers = sorted(
            {match.group(1) for match in block_matches if match is not None},
            key=lambda number: (int(number), number),
        )
        # creates a mapping between the block number and the original names of the layers that belong to this block
        block_to_original_names_in_block = {
            block_number: [name for name in original_names if re.search(rf"layers\.{block_number}/", name)]
            for block_number in blocks_numbers
        }

        # stores the mapping between the old suffix of the layer and the new suffix
        old_suffix_to_new_suffix = {}
        names_mapping = {}

        for block_number, original_names_in_block in block_to_original_names_in_block.items():
            for original_name in original_names_in_block:
                layer = hailo_nn.get_layer_by_original_name(original_name)
                # block-agnostic key, so the same position in every block gets the same suffix
                layer_suffix = original_name.replace(f"layers.{block_number}", "layers.x")
                if layer_suffix in old_suffix_to_new_suffix:
                    suffix = old_suffix_to_new_suffix[layer_suffix]
                else:
                    # creates a new unique suffix
                    suffix = f"{layer.op.value}{type_to_available_index[layer.op.value]}"
                    old_suffix_to_new_suffix[layer_suffix] = suffix
                    type_to_available_index[layer.op.value] += 1
                names_mapping[layer.name] = f"{layer.scope}/block{block_number}__{suffix}"

        # propagates the new names in the model
        for layer in hailo_nn:
            # updates the names of the inputs and outputs of the layer
            for succ in hailo_nn.successors(layer):
                succ.inputs = [names_mapping.get(input_name, input_name) for input_name in succ.inputs]
            for pred in hailo_nn.predecessors(layer):
                pred.outputs = [names_mapping.get(output_name, output_name) for output_name in pred.outputs]

            # if the layer name was changed, updates the name of the layer
            if layer.name in names_mapping:
                output_layers_order = hailo_nn.net_params.output_layers_order
                if layer.name in output_layers_order:
                    output_layers_order[output_layers_order.index(layer.name)] = names_mapping[layer.name]
                layer.name = names_mapping[layer.name]

        return hailo_nn

    def _format_time_milestone(self, start_time):
        """Return the wall time elapsed since ``start_time`` formatted as ``HH:MM:SS.ss``."""
        end_time = time.time()
        hours, rem = divmod(end_time - start_time, 3600)
        minutes, seconds = divmod(rem, 60)
        return f"{int(hours):0>2}:{int(minutes):0>2}:{seconds:05.2f}"

    def save_state(self, path) -> ParserMemento:
        """
        Write the available parsing results into directory ``path`` and return a memento for them.

        Writes ``model.hn`` (if "hn_data" exists), ``native.npz`` (if "bn_rescaled_params" exists) and
        ``results.json`` (if a parsing report exists). The memento always carries all three paths.
        ``return_data`` is not modified, so this can be called more than once.
        """
        path = Path(path)

        # Saving hn
        hn_path = path.joinpath("model.hn")
        if "hn_data" in self._return_data:
            with hn_path.open("w") as fp:
                json.dump(self._return_data["hn_data"], fp)

        # Saving Native params
        weights_path = path.joinpath("native.npz")
        if "bn_rescaled_params" in self._return_data:
            np.savez_compressed(weights_path, **(self._return_data["bn_rescaled_params"]))

        # Saving Results
        results_path = path.joinpath("results.json")
        original_model_meta = self._return_data["original_model_meta"]
        if "parsing_report" in original_model_meta:
            # The pydantic report isn't JSON serializable: dump a converted shallow copy so the
            # in-memory state keeps the ParsingReport object.
            serializable_meta = dict(original_model_meta)
            report = serializable_meta["parsing_report"]
            if not isinstance(report, dict):
                serializable_meta["parsing_report"] = report.dict()
            with results_path.open("w") as fp:
                json.dump(serializable_meta, fp)

        return ParserMemento(hn_path=hn_path, weights=weights_path, results=results_path)

    def load_state(self, memento: ParserMemento, tf_safe: bool = False):
        """
        Restore parsing results from the files referenced by ``memento`` (missing files are skipped).

        ``tf_safe`` is accepted for interface compatibility and is not used here.
        """
        if memento.hn_path is not None and memento.hn_path.exists():
            with memento.hn_path.open("r") as fp:
                self._return_data["hn_data"] = json.load(fp)
        if memento.weights.exists():
            self._return_data["bn_rescaled_params"] = load_params(memento.weights)

        if memento.results.exists():
            with memento.results.open("r") as fp:
                original_model_meta = json.load(fp)
            # return to pydantic object
            report = original_model_meta.get("parsing_report")
            if isinstance(report, dict):
                original_model_meta["parsing_report"] = ParsingReport(**report)
            self._return_data["original_model_meta"] = original_model_meta

    def build_model(self):
        """No-op: the parser flow has nothing to build."""
        pass

    def reset_subprocess(self):
        """No-op: the parser flow keeps no subprocess state to reset."""
        pass

    def is_transformer(self, hailo_nn):
        """
        Return True if any matmul layer in ``hailo_nn`` is reached by a known transformer kqv chain.

        Chains checked (exact match): softmax -> matmul, optionally preceded by a normalization and up
        to two ew_add layers.
        """
        possible_kqv_chains = [
            [FwdChainNode(op=LayerType.softmax), FwdChainNode(op=LayerType.matmul)],
            [
                FwdChainNode(op=LayerType.normalization),
                FwdChainNode(op=LayerType.softmax),
                FwdChainNode(op=LayerType.matmul),
            ],
            [
                FwdChainNode(op=LayerType.normalization),
                FwdChainNode(op=LayerType.ew_add),
                FwdChainNode(op=LayerType.softmax),
                FwdChainNode(op=LayerType.matmul),
            ],
            [
                FwdChainNode(op=LayerType.normalization),
                FwdChainNode(op=LayerType.ew_add),
                FwdChainNode(op=LayerType.ew_add),
                FwdChainNode(op=LayerType.softmax),
                FwdChainNode(op=LayerType.matmul),
            ],
        ]
        # any() stops at the first matmul with a matching chain.
        return any(
            get_all_nodes_from_possible_chains(hailo_nn, layer, possible_kqv_chains, exact_match=True)
            for layer in hailo_nn.nodes()
            if layer.op == LayerType.matmul
        )

    def get_start_nodes_shapes(self, framework, converter, start_node_names):
        """
        Return ``{original_node_name: output_shape}`` for the converter graph nodes matching ``start_node_names``.

        Matching compares ``valid_orig_name(node.name)`` against ``start_node_names`` in topological order,
        taking the first node per requested name. Shapes come from ``get_output_shapes`` with
        ``convert_to_nhwc=False`` for ONNX (True otherwise). ``start_node_names`` is not modified.
        """
        # Work on a copy so the caller's list is left untouched.
        remaining_names = list(start_node_names)
        orig_start_node_names = []
        # use the original layers names, even for models with invalid names (containing ";")
        for node in converter.graph.nodes_toposorted():
            if not remaining_names:
                break
            valid_name = valid_orig_name(node.name)
            if valid_name in remaining_names:
                orig_start_node_names.append(node.name)
                remaining_names.remove(valid_name)

        convert_to_nhwc = framework != NNFramework.ONNX
        return {
            name: converter.graph.get_vertex_by_name(name).get_output_shapes(convert_to_nhwc=convert_to_nhwc)[0]
            for name in orig_start_node_names
        }
