#!/usr/bin/env python

"""Hailo DFC API client."""

import json
import os
import pathlib
import tempfile
from collections import OrderedDict
from contextlib import contextmanager
from typing import Generator, Optional

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_BOX_AND_OBJ_PXLS,
    DistributionStrategy,
    FormatConversionType,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.dataset_util import DatasetContianer
from hailo_model_optimization.acceleras.utils.flow_state.updater import FlowCommands
from hailo_model_optimization.acceleras.utils.params_loader import ParamSerializationType, load_params, save_params
from hailo_model_optimization.flows.inference_flow import InferenceModel, SimulationTrainingModel
from hailo_model_optimization.flows.optimization_flow import SupportedStops
from hailo_model_optimization.tools.orchestator import FlowCheckPoint
from hailo_sdk_client.exposed_definitions import (
    DEFAULT_HW_ARCH,
    NON_SUPPORTED_HW_ARCHS,
    PARTIALLY_SUPPORTED_HW_ARCHS,
    SUPPORTED_HW_ARCHS,
    CalibrationDataType,
    ContextInfo,
    InferenceContext,
    InferenceDataType,
    JoinAction,
    NNFramework,
    States,
)
from hailo_sdk_client.hailo_archive.hailo_archive import HailoArchive
from hailo_sdk_client.hw_consts.hw_arch import HWArch
from hailo_sdk_client.model_translator.parsing_report import ParsingReport
from hailo_sdk_client.paths_manager.platform_importer import get_platform
from hailo_sdk_client.quantization.quantize import data_to_dataset
from hailo_sdk_client.runner import utils as runner_utils
from hailo_sdk_client.runner.exceptions import (
    HailoPlatformMissingException,
    HNNotSetException,
    InvalidArgumentsException,
    UnsupportedCustomSessionException,
    UnsupportedRunnerJoinException,
    UnsupportedTargetException,
)
from hailo_sdk_client.sdk_backend.parser.parser import Parser
from hailo_sdk_client.sdk_backend.script_parser.nms_postprocess_command import ANCHORLESS_YOLOS, NMSPostprocessCommand
from hailo_sdk_client.sdk_backend.sdk_backend import SDKBackend
from hailo_sdk_client.sdk_backend.sdk_backend_data_class import CheckpointInfo, InferInfo, InternalContextInfo
from hailo_sdk_client.tools.core_postprocess.nms_postprocess import NMSConfig, NMSMetaData
from hailo_sdk_client.tools.network_join import get_runner_scopes, handle_model_script_join, merge_hns, merge_params
from hailo_sdk_common.compatibility import file_types, string_types
from hailo_sdk_common.export.hailo_graph_export import ExportLevel, GraphExport, HailoGraphExport, OutputTensorsExport
from hailo_sdk_common.hailo_nn.exceptions import HailoNNException
from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, LayerType, NMSMetaArchitectures
from hailo_sdk_common.hailo_nn.nms_postprocess_defaults import DEFAULT_YOLO_ANCHORS
from hailo_sdk_common.logger.logger import DeprecationVersion, default_logger
from hailo_sdk_common.model_params.model_params import ModelParams
from hailo_sdk_common.onnx_tools.hailo_onnx_model_composer import (
    HailoONNXModelComposer,
    UnsupportedHailoRuntimeException,
)
from hailo_sdk_common.paths_manager.paths import SDKPaths
from hailo_sdk_common.paths_manager.SimWrapper import HSimWrapper
from hailo_sdk_common.states.states import InvalidStateException, allowed_states
from hailo_sdk_common.targets.infer_wrapper import HefInferWrapper
from hailo_sdk_common.targets.inference_targets import EmulationObject, ParamsKinds

# Slim variants of the quantized / compiled states.
SLIM_STATES = [States.QUANTIZED_SLIM_MODEL, States.COMPILED_SLIM_MODEL]
# Every state except UNINITIALIZED.
INITIALIZED_STATES = [state for state in States if state != States.UNINITIALIZED]
# Every state except UNINITIALIZED and ORIGINAL_MODEL; used as the allowed states of
# the methods that need a loaded HN.
HN_STATES = [state for state in States if state not in (States.UNINITIALIZED, States.ORIGINAL_MODEL)]
# Full and slim quantized states.
QUANTIZED_STATES = [States.QUANTIZED_MODEL, States.QUANTIZED_SLIM_MODEL]


class ClientRunner:
    """Hailo DFC API client."""

    def __init__(self, hn=None, hw_arch=None, har=None):
        """
        Create a DFC client runner.

        Args:
            hn: Hailo network description (HN), as a file-like object, string, dict, or
                :class:`~hailo_sdk_common.hailo_nn.hailo_nn.HailoNN`. Use None if you intend to
                parse the network description from Tensorflow later.
                Notice: This flag will be deprecated soon.
            hw_arch (str, optional): Hardware architecture to be used. Defaults to ``hailo8``.
            har (str or :class:`~hailo_sdk_common.hailo_archive.hailo_archive.HailoArchive`, optional): Hailo
                Archive file path or Hailo Archive object to initialize the runner from.

        Raises:
            InvalidArgumentsException: If ``hw_arch`` is not a supported architecture, or if both
                ``hn`` and ``har`` are given.

        """
        SDKPaths().set_client_build_dir_path()
        # Reset custom build dir
        SDKPaths().custom_build_dir = None

        self._logger = default_logger()

        if hn is not None:
            # TODO: https://hailotech.atlassian.net/browse/SDK-46364
            self._logger.deprecation_warning(
                "The hn flag will be deprecated soon, please use `har` instead.",
                DeprecationVersion.APR2024,
            )

        # Validate the requested architecture (None means "use the default").
        if hw_arch is not None:
            if hw_arch in NON_SUPPORTED_HW_ARCHS:
                raise InvalidArgumentsException(
                    f"{hw_arch} is not a valid hw arch. Please use Dataflow Compiler v5.x instead",
                )
            if hw_arch not in SUPPORTED_HW_ARCHS + PARTIALLY_SUPPORTED_HW_ARCHS:
                raise InvalidArgumentsException(
                    f'{hw_arch} is not a valid hw arch. Please choose one of: {", ".join(SUPPORTED_HW_ARCHS)}',
                )

        self._hw_arch = HWArch.get_real_hw_arch(hw_arch)
        self._state = States.UNINITIALIZED
        self._original_model_path = None
        self._auto_model_script = None
        self._temp_dir = None
        self._hef = None
        self._preprocess_model = None
        self._postprocess_model = None
        self._original_model_meta = {}

        # Waiting for hn
        self._mid = None
        self._model_name = None
        self._sdk_backend = None

        # Waiting for params
        HSimWrapper().load()
        self._cached_model = None
        self._sub_models = None
        self._number_of_sub_models = 0

        # The runner is initialized from at most one source: a HAR or an HN.
        if har is not None:
            if hn is not None:
                raise InvalidArgumentsException("Client Runner can get either HN or HAR but not both.")
            self.load_har(har=har)
        elif hn is not None:
            self.set_hn(hn)

        self._use_service = False
        self._hailo_platform = get_platform()

    def __enter__(self):
        pass

    def __exit__(self, type, value, traceback):
        pass

    def __del__(self):
        if self._temp_dir is not None:
            self._temp_dir.cleanup()

    def _generate(self, hn):
        """Build an :class:`SDKBackend` from the given network description and this runner's hw arch."""
        hailo_nn = self._load_hn(hn)
        return SDKBackend(
            hn=hailo_nn.to_hn(hailo_nn.name, json_dump=False),
            hw_arch=HWArch(self._hw_arch),
        )

    @property
    def _hn(self):
        if self._sdk_backend is None:
            return None
        return self._sdk_backend.model

    @property
    def model_script(self):
        if self._sdk_backend is None:
            return None

        return self._sdk_backend.model_script

    @property
    def _force_weightless_model(self):
        if self._sdk_backend is None:
            return False

        return self._sdk_backend.force_weightless_model

    @property
    def modifications_meta_data(self):
        if self._sdk_backend is None:
            return {}
        return self._sdk_backend.modifications_meta_data

    @allowed_states(States.HAILO_MODEL)
    def force_weightless_model(self, weightless=True):
        """
        DFC API that switches the model's weightless mode on or off.

        While weightless mode is enabled, the software emulation graph can be obtained
        from :func:`~hailo_sdk_client.runner.client_runner.ClientRunner.get_tf_graph`
        even if no parameters have been loaded.

        Note:
            Such a graph cannot be used for running inference unless the model does not require
            weights.

        Args:
            weightless (bool): Set to True to enable weightless mode. Defaults to True.

        Returns:
            bool: The ``weightless`` value that was stored on the SDK backend.

        """
        backend = self._sdk_backend
        backend.force_weightless_model = weightless
        return weightless

    @allowed_states(States.QUANTIZED_MODEL, States.QUANTIZED_SLIM_MODEL)
    def set_keras_model(self, model: SimulationTrainingModel):
        """
        Set Keras model after quantization-aware training.
        This method allows you to set the model after editing it externally.
        After setting the model, new quantized weights are generated.

        Args:
            model (SimulationTrainingModel): model to set.

        Raises:
            InvalidArgumentsException: If ``model`` is not a trainable ``SimulationTrainingModel``.

        """
        # Check the type first: an object of another type may not have `is_trainable` at all,
        # and should be rejected with the documented exception rather than an AttributeError.
        if not isinstance(model, SimulationTrainingModel) or not model.is_trainable:
            raise InvalidArgumentsException("Set can only be used with trainable model.")

        slim_mode = self.state == States.QUANTIZED_SLIM_MODEL
        self._sdk_backend.set_emulation_model(model, slim_mode=slim_mode)

    @allowed_states(
        States.HAILO_MODEL,
        States.FP_OPTIMIZED_MODEL,
        States.QUANTIZED_MODEL,
        States.QUANTIZED_BASE_MODEL,
        States.COMPILED_MODEL,
    )
    def get_keras_model(self, context: ContextInfo, trainable=False) -> InferenceModel:
        """Return a Keras model for inference in the given context.
        The model is produced for native, fp-optimized, quantized, bit-exact or HW inference,
        according to the inference context.
        Editing the keras model won't affect quantization/compilation unless
        :func:`~hailo_sdk_client.runner.client_runner.ClientRunner.set_keras_model` API is being used.

        Args:
            context (:class:`~hailo_sdk_client.exposed_definitions.ContextInfo`):
                inference context generated by infer_context.
            trainable (bool, optional):
                indicate whether the returned model should be trainable or not.
                :func:`~hailo_sdk_client.runner.client_runner.ClientRunner.set_keras_model` only supports trainable
                models.

        Raises:
            UnsupportedTargetException: If the context is not open or its inference context is unknown.
            NotImplementedError: If a trainable model is requested outside of SDK_QUANTIZED.
            HailoPlatformMissingException: If HW context is requested without the hailo_platform package.

        Example:
            >>> with runner.infer_context(InferenceContext.SDK_NATIVE) as ctx:
            >>>     result = runner.get_keras_model(ctx)

        """
        infer_context = context.infer_context
        if not context.open:
            raise UnsupportedTargetException(
                f"Unsupported context for get_keras_model {infer_context}. "
                f"Are you running infer inside inference context? "
                f"(`with runner.infer_context():`)",
            )

        emulation_contexts = (
            InferenceContext.SDK_NATIVE,
            InferenceContext.SDK_FP_OPTIMIZED,
            InferenceContext.SDK_QUANTIZED,
            InferenceContext.SDK_BIT_EXACT,
        )

        # Software emulation: only the quantized context can hand out a trainable model.
        if infer_context in emulation_contexts:
            if trainable and infer_context != InferenceContext.SDK_QUANTIZED:
                raise NotImplementedError(
                    f"Context {infer_context} does not support trainable model. Please use SDK_QUANTIZED instead.",
                )
            return self._sdk_backend.get_emulation_model(context, trainable)

        # Hardware: requires the HailoRT python package and never supports training.
        if infer_context == InferenceContext.SDK_HAILO_HW:
            if not self._hailo_platform:
                raise HailoPlatformMissingException(
                    "The HailoRT Python API (hailo_platform package) "
                    "must be installed in order to use the Hailo hardware",
                )
            if trainable:
                raise NotImplementedError(
                    f"Context {infer_context} does not support trainable model. Please use SDK_QUANTIZED instead.",
                )
            return self._sdk_backend.get_hw_model(context.graph_export)

        raise UnsupportedTargetException(
            f"Unsupported context for infer {infer_context}. "
            f"Are you running infer inside inference context? "
            f"(`with runner.infer_context():`)",
        )

    @allowed_states(
        States.HAILO_MODEL,
        States.FP_OPTIMIZED_MODEL,
        States.QUANTIZED_MODEL,
        States.QUANTIZED_BASE_MODEL,
        States.COMPILED_MODEL,
        States.COMPILED_SLIM_MODEL,
    )
    def infer(
        self,
        context: ContextInfo,
        dataset,
        data_type=InferenceDataType.auto,
        data_count: Optional[int] = None,
        batch_size: int = 8,
    ) -> list:
        """DFC API for inference.
        This method infers the given dataset on the model in either full-precision, emulation (quantized),
        or HW and returns the output.

        Args:
            context (:class:`~hailo_sdk_client.exposed_definitions.ContextInfo`): inference context generated by
                infer_context
            dataset: data for Inference. The type depends on the ``data_type`` parameter.
            data_type (:class:`~hailo_sdk_client.exposed_definitions.InferenceDataType`):
                dataset's data type, based on enum values:

                * :attr:`~hailo_sdk_client.exposed_definitions.InferenceDataType.auto` -- Automatic detection.
                * :attr:`~hailo_sdk_client.exposed_definitions.InferenceDataType.np_array` -- ``numpy.ndarray``,
                  or dictionary with input layer names as keys, and values types of ``numpy.ndarray``.
                * :attr:`~hailo_sdk_client.exposed_definitions.InferenceDataType.dataset` -- ``tensorflow.data.Dataset``
                  object with a valid signature. signature should be either ((h, w, c), image_info) or
                  ({'input_layer1': (h1, w1, c1), 'input_layer2': (h2, w2, c2)}, image_info) image_info can
                  be an empty dict for inference
                * :attr:`~hailo_sdk_client.exposed_definitions.InferenceDataType.npy_file` -- path to a npy or npz file
                * :attr:`~hailo_sdk_client.exposed_definitions.InferenceDataType.npy_dir` -- path to a npy or npz dir,
                  assumes the same shape to all the items

            data_count (int): optional argument to limit the number of elements for inference
            batch_size (int): batch size for inference

        Returns:
            :list: list of outputs. Entry i in the list is the output of input i.
            In case the model contains more than one output, each entry is a list of all the outputs.

        Raises:
            UnsupportedTargetException: If the context is not open, the inference context is unknown, or the
                inference context is not allowed in the current state.
            HailoPlatformMissingException: If HW inference is requested without the hailo_platform package.

        Example:
            >>> with runner.infer_context(InferenceContext.SDK_NATIVE) as ctx:
            >>>     result = runner.infer(
            ...         ctx,
            ...         dataset=tf.data.Dataset.from_tensor_slices(np.ones((1, 10))),
            ...         batch_size=1
            ...     )

        """
        if not context.open:
            raise UnsupportedTargetException(
                f"Unsupported context for infer {context.infer_context}. "
                f"Are you running infer inside inference context? "
                f"(`with runner.infer_context():`)",
            )

        if self.state == States.COMPILED_SLIM_MODEL and context.infer_context != InferenceContext.SDK_HAILO_HW:
            raise UnsupportedTargetException(
                f"Only hardware inference is allowed in compiled slim state (got {self.state})."
            )
        elif (
            self.state == States.QUANTIZED_BASE_MODEL
            and context.infer_context not in [InferenceContext.SDK_NATIVE, InferenceContext.SDK_FP_OPTIMIZED]
            and context.lora_adapter_name != self._sdk_backend.get_lora_adapters()[0]
        ):
            # The trailing space after "quantized" matters: these adjacent literals are concatenated.
            raise UnsupportedTargetException(
                "Only native and fp-optimized inference contexts are allowed in partially quantized "
                f"state (got {self.state}). To run quantized inference context, first use the optimize "
                "API to quantize the recently added LoRA adapter."
            )

        # Re-key a dict dataset by the layer names as the HN resolves them.
        if isinstance(dataset, dict):
            dataset = {self._hn.get_layer_by_name(key).name: value for key, value in dataset.items()}
        data, _ = data_to_dataset(dataset, data_type, self._logger)

        infer_info = InferInfo(context_info=context, data=data, batch_size=batch_size, data_count=data_count)

        if context.infer_context in [
            InferenceContext.SDK_NATIVE,
            InferenceContext.SDK_FP_OPTIMIZED,
            InferenceContext.SDK_QUANTIZED,
            InferenceContext.SDK_BIT_EXACT,
        ]:
            return self._infer_emulator(infer_info)

        elif context.infer_context == InferenceContext.SDK_HAILO_HW:
            if not self._hailo_platform:
                raise HailoPlatformMissingException(
                    "The HailoRT Python API (hailo_platform package) "
                    "must be installed in order to use the Hailo hardware",
                )
            return self._infer_hw(infer_info)

        else:
            raise UnsupportedTargetException(
                f"Unsupported context for infer {context.infer_context}. "
                f"Are you running infer inside inference context? "
                f"(`with runner.infer_context():`)",
            )

    def _infer_emulator(self, infer_info: InferInfo):
        """Run software-emulated inference through the SDK backend.

        Args:
            infer_info (InferInfo): Inference request (context, data, batch size and data count).

        Returns:
            The result of the backend's ``acceleras_inference``.

        Raises:
            InvalidStateException: If a quantized / bit-exact context is requested for a model that
                requires quantized weights while the runner is not in a quantized or compiled state.

        """
        # Read the `_force_weightless_model` property here. `self.force_weightless_model` is the
        # bound API method, which is always truthy and would silently disable the check below.
        requires_quantized_weights = self._hn.requires_quantized_weights and not self._force_weightless_model
        # NOTE(review): QUANTIZED_BASE_MODEL is not in the allowed list below although `infer` lets
        # some quantized contexts through in that state -- confirm whether it should be added.
        if (
            infer_info.context_info.infer_context in [InferenceContext.SDK_QUANTIZED, InferenceContext.SDK_BIT_EXACT]
            and requires_quantized_weights
            and self._state not in [States.QUANTIZED_MODEL, States.COMPILED_MODEL]
        ):
            raise InvalidStateException(
                f"Infer in context {infer_info.context_info.infer_context} is invalid in {self._state.value} "
                f"state. Please run 'optimize' to obtain quantized weights and try again.",
            )
        return self._sdk_backend.acceleras_inference(infer_info)

    def _get_keras_input_nodes(self):
        """Create one float32 ``tf.keras.Input`` per HN input layer, keyed by layer name.

        Each input's shape is the layer's ``output_shape`` without its first dimension.
        """
        input_names = [input_layer.name for input_layer in self._hn.get_input_layers()]
        keras_inputs = {}
        with tf.Graph().as_default():
            for name in input_names:
                layer_shape = self._hn.get_layer_by_name(name).output_shape[1:]
                keras_inputs[name] = tf.keras.Input(dtype=tf.float32, shape=layer_shape)
        return keras_inputs

    # remove this function when async api on hailo8 is enabled SDK-51150
    def _get_network_groups_for_inference(self, target, context):
        """Configure the compiled HEF on ``target`` and return what ``target.configure`` returns.

        Raises:
            InvalidStateException: If no HEF is available; the message hints at which APIs
                still have to be run, depending on the current state.

        """
        if self._hef is None:
            message = f"Infer with target {context.name} is invalid in {self._state.value} state."
            if self._state in [States.HAILO_MODEL, States.FP_OPTIMIZED_MODEL]:
                message += " Please run `optimize` and `compile` to obtain a compiled model."
            elif self._state in QUANTIZED_STATES:
                message += " Please run `compile` to obtain a compiled model."
            raise InvalidStateException(message)

        self._logger.debug("Configuring HEF to Hailo HW")
        #### NOTE: moving this here as WA for issue with loading HEF from string instead of file
        # revert when https://hailotech.atlassian.net/browse/HRT-14006 is done
        hef_is_existing_path = isinstance(self._hef, str) and os.path.exists(self._hef)
        if hef_is_existing_path:
            hef_path = self._hef
        else:
            # The HEF content is held in memory: dump it to "<model name>.hef" in the working directory.
            hef_path = os.path.join(os.getcwd(), f"{self.model_name}.hef")
            with open(hef_path, "wb") as hef_file:
                hef_file.write(self._hef)
        ####
        return target.configure(self._hailo_platform.HEF(hef_path))

    def _infer_hw(self, infer_info: InferInfo):
        """Delegate the inference request to the SDK backend's ``hw_inference`` and return its result."""
        hw_results = self._sdk_backend.hw_inference(infer_info)
        return hw_results

    @allowed_states(*HN_STATES)
    def load_model_script(self, model_script=None, append=False):
        """
        DFC API for manipulating the model build params. The given script is loaded and applied
        to the existing HN, i.e., it modifies the specific params in each layer, and it is kept as
        the model build script for later use. Any previously stored HEF is discarded.

        Args:
            model_script (str, pathlib.Path): A model script is given as either a path to the ALLS file or commands as
                a string allowing the modification of the current model, before quantization / native emulation /
                profiling, etc.
                The SDK parses the script, and applies the commands as follows:

                1. Model modification related commands -- These commands are executed during
                   optimization.
                2. Quantization related commands -- Some of these commands modify the HN, so after
                   the modification, each layer (possibly) has new quantization parameters. Other
                   commands are executed during optimization.
                3. Allocation and compilation related commands -- These commands are executed during
                   compilation.

            append (boolean): Whether to append the commands to a previous script (if exists) or use only the new
                script. Addition is allowed only in native mode. Defaults to False.

        Returns:
            dict: A copy of the new modified HN (JSON dictionary).

        Raises:
            InvalidStateException: If ``append`` is requested in a state that does not allow it.
            InvalidArgumentsException: If ``model_script`` is neither a valid script nor an existing path.

        """
        if append and self.state != States.HAILO_MODEL:
            # Outside the native state, the only command that may be appended is the
            # max-performance one, and only to a quantized model.
            max_performance_cmd = "performance_param(compiler_optimization_level=max)"
            appending_max_perf = self.state in QUANTIZED_STATES and model_script == max_performance_cmd
            if not appending_max_perf:
                raise InvalidStateException(
                    f"Append mode is supported only on the native model, but got {self.state} instead."
                )

        if append:
            action_string = "Appending"
        else:
            action_string = "Loading"

        script_ok, err_info = runner_utils.is_string_model_script_or_path(model_script, self._hn)
        if not script_ok:
            raise InvalidArgumentsException(f"either model script is illegal or file path doesn't exist: {err_info}")

        given_as_file = isinstance(model_script, pathlib.PurePath) or os.path.isfile(model_script)
        if given_as_file:
            self._logger.info(f"{action_string} model script commands to {self._model_name} from {model_script}")
            self._sdk_backend.load_model_script_from_file(model_script, append)
        else:
            self._logger.info(f"{action_string} model script commands to {self._model_name} from string")
            self._sdk_backend.load_model_script(model_script, append)

        self._hef = None
        return self.get_hn()

    @allowed_states(*HN_STATES)
    def load_params(self, params, params_kind=None):
        """
        Load network params (weights) into the SDK backend.

        Loading FP-optimized or translated (quantized) params moves the runner to the matching
        state. Any previously stored HEF is discarded.

        Args:
            params: If a string, this is treated as the path of the npz file to load. If a dict,
                this is treated as the params themselves, where the keys are strings and the values
                are numpy arrays.
            params_kind (str, optional): Indicates whether the params to be loaded are native, native
                after BN fusion, or quantized.

        Returns:
            str: Kind of params that were actually loaded.

        Raises:
            ValueError: If ``params`` is of an unsupported type.

        """
        state_before_load = self.state

        # Normalize every accepted input form to a ModelParams object.
        if isinstance(params, (string_types, file_types, pathlib.Path)):
            params_to_use = ModelParams(load_params(params, type_=".npz"))
        elif isinstance(params, ModelParams):
            params_to_use = params
        elif isinstance(params, dict):
            params_to_use = ModelParams(params)
        else:
            raise ValueError("params must be one of the following: string (file path), file, dict or ModelParams")

        self._validate_slim_params_kind(params_to_use, params_kind)
        loaded_kind = self._sdk_backend.load_params(params_to_use, params_kind)

        if params_to_use is not None:
            if loaded_kind is ParamsKinds.FP_OPTIMIZED:
                self._state = States.FP_OPTIMIZED_MODEL
            elif loaded_kind is ParamsKinds.TRANSLATED:
                # A runner that was slim before loading stays slim.
                if state_before_load in SLIM_STATES:
                    self._state = States.QUANTIZED_SLIM_MODEL
                else:
                    self._state = States.QUANTIZED_MODEL

        self._hef = None

        if loaded_kind is None:
            return ParamsKinds.NATIVE
        return loaded_kind

    def _validate_slim_params_kind(self, params: ModelParams, params_kind):
        """Ensure that only translated (quantized) params are loaded while in a slim state.

        Outside the slim states this is a no-op. When ``params_kind`` is not given, the kind is
        taken from the params object itself.

        Raises:
            InvalidStateException: If the runner is in a slim state and the params kind is missing
                or is not ``ParamsKinds.TRANSLATED``.

        """
        if self.state not in SLIM_STATES:
            return

        kind = params_kind
        if not kind:
            # Prefer the enum recorded on the params object; otherwise use the raw stored entry.
            kind = params.params_kind_enum or params._params.get(ModelParams.PARAMS_KIND_STR)

        if not kind or kind != ParamsKinds.TRANSLATED:
            raise InvalidStateException(
                f"Loading params in slim mode is allowed only for quantized params, got {kind}."
            )

    @allowed_states(*HN_STATES)
    def save_params(self, path, params_kind=ParamsKinds.NATIVE):
        """
        Save the model params of the requested kind to a npz file.

        Args:
            path (str): Path of the npz file to save.
            params_kind (str, optional): Indicates whether the params to be saved are native, native
                after BN fusion, or quantized.

        Raises:
            InvalidStateException: If the requested kind is not available in the current state.
            ValueError: If ``params_kind`` is not a recognized kind.

        """
        if self.state in SLIM_STATES and params_kind != ParamsKinds.TRANSLATED:
            raise InvalidStateException(f"save_params with {params_kind} kind is invalid under {self._state} state")

        # Pick the backend getter that matches the requested kind, then call it once.
        if params_kind == ParamsKinds.TRANSLATED:
            if self._state in [States.HAILO_MODEL, States.FP_OPTIMIZED_MODEL]:
                raise InvalidStateException(
                    f"save_params with {params_kind} kind is invalid under {self._state.value} state",
                )
            fetch_params = self._sdk_backend.get_params_translated
        elif params_kind in (ParamsKinds.NATIVE_FUSED_BN, ParamsKinds.NATIVE):
            fetch_params = self._sdk_backend.get_params
        elif params_kind == ParamsKinds.FP_OPTIMIZED:
            fetch_params = self._sdk_backend.get_params_fp_optimized
        elif params_kind == ParamsKinds.STATISTICS:
            fetch_params = self._sdk_backend.get_params_statistics
        else:
            raise ValueError(f"Unexpected params_kind {params_kind}")

        type(self)._save_params(fetch_params(), path)
        self._logger.info(f"Saved params from {self._model_name} at {path}")

    @allowed_states(*HN_STATES)
    def _get_previous_hailo_export(self):
        """Return the most recent Hailo export cached on this runner."""
        last_export = self._cached_model
        return last_export

    def _get_sw_tf_graph(
        self,
        target,
        nodes,
        rescale_output,
        translate_input,
        custom_session=None,
        twin_mode=False,
        native_layers=None,
        run_numeric_in_int32=False,
    ):
        """Build the software (emulation) TF graph through the SDK backend and cache the result.

        All arguments are forwarded to the backend's ``_get_tf_graph``. The returned export, the
        sub-graph export and the number of sub-graphs are stored on the runner.

        Returns:
            The Hailo export returned by the backend.

        Raises:
            InvalidStateException: If a numeric target is requested for a model that requires
                quantized weights while the runner is still in the native or fp-optimized state.

        """
        # Read the `_force_weightless_model` property here. `self.force_weightless_model` is the
        # bound API method, which is always truthy and would silently disable the check below.
        requires_quantized_weights = self._hn.requires_quantized_weights and not self._force_weightless_model
        if (
            self._is_numeric(target)
            and requires_quantized_weights
            and self._state in [States.HAILO_MODEL, States.FP_OPTIMIZED_MODEL]
        ):
            raise InvalidStateException(
                f"get_tf_graph with target {target.name} with model that requires quantized weights is "
                f"invalid with {self._state.value} state",
            )

        self._logger.debug("Preparing SW model")

        hailo_export, sub_graph_export, number_of_sub_graphs = self._sdk_backend._get_tf_graph(
            target=target,
            nodes=nodes,
            rescale_output=rescale_output,
            translate_input=translate_input,
            twin_mode=twin_mode,
            native_layers=native_layers,
            run_numeric_in_int32=run_numeric_in_int32,
            custom_session=custom_session,
        )

        self._cached_model = hailo_export
        self._sub_models = sub_graph_export
        self._number_of_sub_models = number_of_sub_graphs
        self._logger.info("SW TF graph ready")
        return hailo_export

    @staticmethod
    def _get_infer_wrapper(nodes, target, rescale_output=True, translate_input=True, network_groups=None):
        """Create a :class:`HefInferWrapper` for the given input nodes.

        Output names are taken from the first network group when ``network_groups`` is given,
        otherwise from ``target.sorted_output_layer_names``.
        """
        input_names = list(nodes.keys())
        if network_groups is None:
            output_names = target.sorted_output_layer_names
        else:
            output_names = network_groups[0].get_sorted_output_names()

        wrapper_kwargs = {
            "infer_model": None,  # remove this two lines when async api on hailo8 is enabled SDK-51150
            "device": None,
            "input_names": input_names,
            "translate_input": translate_input,
            "rescale_output": rescale_output,
            "output_names": output_names,
            "network_groups": network_groups,
        }
        return HefInferWrapper(**wrapper_kwargs)

    def _get_hw_tf_graph(
        self,
        target,
        nodes,
        rescale_output,
        translate_input,
        custom_session=None,
        twin_mode=False,
        fps=4,
        use_preloaded_compilation=False,
        mapping_timeout=None,
        allocator_script_filename=None,
        network_groups=None,
    ):
        """Build a TF graph whose outputs are computed on Hailo hardware, and cache the export.

        Args:
            target: Hardware target; must provide ``hef_loaded``, ``configure``,
                ``loaded_network_groups``, ``sorted_output_layer_names`` and ``get_output_shapes``.
            nodes (dict): Input tensors keyed by input name; the graph of the first value is used.
            rescale_output (bool): Forwarded to the infer wrapper; when true, all outputs are float32.
            translate_input (bool): Forwarded to the infer wrapper.
            custom_session: Session to use instead of creating a new one on the nodes' graph.
            twin_mode (bool): Not supported on HW; must be False.
            fps, mapping_timeout, allocator_script_filename: Forwarded to ``_compile`` when no HEF exists yet.
            use_preloaded_compilation (bool): When true and the target already has a HEF loaded,
                the target is not configured again.
            network_groups: Network groups to use with a preloaded compilation; defaults to the
                target's loaded network groups.

        Returns:
            HailoGraphExport: The export wrapping the session, graph, tensors and HEF.

        Raises:
            UnsupportedTargetException: If ``twin_mode`` is set.

        """
        # Any input node works here: only its graph is needed.
        node = next(iter(nodes.values()))
        if twin_mode:
            raise UnsupportedTargetException("Twin mode is not supported when running on HW")

        # Configure the target unless it already holds a HEF and the caller asked to reuse it.
        if (not target.hef_loaded()) or (not use_preloaded_compilation):
            if not self._hef:
                self._logger.debug("Preparing HW model")
                self._compile(
                    fps=fps,
                    mapping_timeout=mapping_timeout,
                    allocator_script_filename=allocator_script_filename,
                )
            hef = self._hailo_platform.HEF(self._hef)
            network_groups = target.configure(hef)
        elif network_groups is None:
            network_groups = target.loaded_network_groups

        self._logger.debug("Loading HW TF graph")

        session = custom_session if custom_session is not None else tf.compat.v1.Session(graph=node.graph)
        # NOTE(review): `network_groups` is not passed to the wrapper here, so it is created with
        # network_groups=None and takes its output names from the target -- confirm this is intended.
        infer_wrapper = self._get_infer_wrapper(nodes, target, rescale_output, translate_input)

        input_tensors = list(nodes.values())
        with node.graph.as_default(), session.as_default():
            # Rescaled outputs are float32; otherwise the HN's own output dtypes are used.
            dtypes = [(tf.float32 if rescale_output else dtype) for dtype in self._hn.get_output_dtypes()]
            out = tf.numpy_function(infer_wrapper.tf_infer, input_tensors, dtypes, name="infer_hw_py_func")
            # Set each output's static shape explicitly, with an unknown leading dimension.
            for output_tensor, output_shape in zip(out, self._get_target_output_shapes(target)):
                output_tensor.set_shape([None, *list(output_shape)])
        tensors_export = {
            ExportLevel.OUTPUT_LAYERS: OutputTensorsExport(
                export_level=ExportLevel.OUTPUT_LAYERS,
                tensors=out,
                layers_names=target.sorted_output_layer_names,
            ),
        }

        hailo_export = HailoGraphExport(
            session=session,
            graph=node.graph,
            input_tensors=nodes,
            init_output_exports=tensors_export,
            hef=self._hef,
            network_groups=network_groups,
            hef_infer_wrapper=infer_wrapper,
        )
        try:
            # NOTE(review): `_hn_dict` is not defined in the visible part of this class -- verify it exists.
            hailo_export.update_original_names(self._hn_dict)
        except HailoNNException:
            # This might happen in some cases such as loopback (IB->OB)
            self._logger.warning(
                "Unable to connect between hardware output tensors to their original names in the Tensorflow model",
            )

        self._cached_model = hailo_export
        self._logger.info("HW TF graph ready")
        return hailo_export

    def _get_target_output_shapes(self, target):
        return target.get_output_shapes()

    @allowed_states(*HN_STATES)
    def _get_hef_tf_graph(
        self,
        target,
        nodes=None,
        translate_input=None,
        rescale_output=None,
        custom_session=None,
        twin_mode=False,
        native_layers=None,
        fps=4,
        use_preloaded_compilation=False,
        mapping_timeout=None,
        allocator_script_filename=None,
        network_groups=None,
        run_numeric_in_int32=False,
    ):
        """
        Build the graph export for ``target`` by dispatching to the emulation or the hardware builder.

        Emulation targets (``EmulationObject`` instances) are forwarded to ``_get_sw_tf_graph``; any
        other target is treated as hardware and forwarded to ``_get_hw_tf_graph``.

        Args:
            target: Inference target. ``EmulationObject`` instances select the emulation path.
            nodes (dict or tensor, optional): Input layer names mapped to input tensors. A non-dict
                value is attached to the first HN input layer. When None and the target is
                hardware, a placeholder is created for every HN input layer.
            translate_input (bool, optional): Defaults to ``self._is_numeric(target)`` when None.
            rescale_output (bool, optional): Defaults to ``self._is_numeric(target)`` when None.
            custom_session (optional): Session to use; only allowed together with ``nodes``.
            twin_mode (bool, optional): Forwarded to the selected builder.
            native_layers (optional): Forwarded to the emulation builder only.
            fps (optional): Forwarded to the hardware builder only.
            use_preloaded_compilation (bool, optional): Forwarded to the hardware builder only.
            mapping_timeout (optional): Forwarded to the hardware builder only.
            allocator_script_filename (optional): Forwarded to the hardware builder only.
            network_groups (optional): Forwarded to the hardware builder only.
            run_numeric_in_int32 (bool, optional): Forwarded to the emulation builder only.

        Raises:
            UnsupportedCustomSessionException: A custom session was given without ``nodes``.
            HailoPlatformMissingException: A hardware target was requested but
                ``self._hailo_platform`` is not available.
            InvalidArgumentsException: The names in ``nodes`` do not match the HN input layers.

        """
        if translate_input is None:
            translate_input = self._is_numeric(target)
        if rescale_output is None:
            rescale_output = self._is_numeric(target)

        hn_input_names = [input_layer.name for input_layer in self._hn.get_input_layers()]
        single_tensor_given = nodes is not None and not isinstance(nodes, dict)
        if single_tensor_given:
            # A bare tensor is attached to the first input layer of the HN
            nodes = {hn_input_names[0]: nodes}

        # Without a preprocessing graph, default placeholders are created below; those cannot be
        # combined with a session supplied by the caller.
        if custom_session is not None and nodes is None:
            raise UnsupportedCustomSessionException("No preprocessing graph was given with custom session.")

        # TODO: target is not an enum in the case of HW. Maybe separate SW and HW APIs?
        if isinstance(target, EmulationObject):
            return self._get_sw_tf_graph(
                target,
                nodes,
                rescale_output,
                translate_input,
                custom_session=custom_session,
                twin_mode=twin_mode,
                native_layers=native_layers,
                run_numeric_in_int32=run_numeric_in_int32,
            )

        if not self._hailo_platform:
            raise HailoPlatformMissingException(
                "The HailoRT Python API (hailo_platform package) must be installed in "
                "order to use the Hailo hardware",
            )

        if nodes is not None:
            # Re-key the given nodes by the layer names as the HN resolves them
            nodes = {self._hn.get_layer_by_name(name).name: tensor for name, tensor in nodes.items()}
        else:
            sixteen_bit_modes = (
                PrecisionMode.a16_w16,
                PrecisionMode.a16_w16_a16,
                PrecisionMode.a16_w16_a8,
            )
            nodes = {}
            with tf.Graph().as_default():
                for name in hn_input_names:
                    input_layer = self._hn.get_layer_by_name(name)
                    is_16_bit = input_layer.precision_config.precision_mode in sixteen_bit_modes
                    raw_dtype = tf.uint16 if is_16_bit else tf.uint8
                    # The placeholder is float32 when the input still has to be translated,
                    # otherwise it takes the raw unsigned dtype of the layer
                    placeholder_dtype = tf.float32 if translate_input else raw_dtype
                    nodes[name] = tf.keras.Input(
                        dtype=placeholder_dtype,
                        name="default_placeholder",
                        shape=input_layer.output_shape[1:],
                    )

        nodes = OrderedDict(nodes)
        if set(hn_input_names) != set(nodes):
            raise InvalidArgumentsException(
                "Unexpected mismatch between model input tensors and feed input tensors names.",
            )

        return self._get_hw_tf_graph(
            target,
            nodes,
            rescale_output,
            translate_input,
            custom_session=custom_session,
            twin_mode=twin_mode,
            fps=fps,
            use_preloaded_compilation=use_preloaded_compilation,
            mapping_timeout=mapping_timeout,
            allocator_script_filename=allocator_script_filename,
            network_groups=network_groups,
        )

    def _get_tf_graph(
        self,
        target,
        nodes=None,
        translate_input=None,
        rescale_output=None,
        custom_session=None,
        twin_mode=False,
        native_layers=None,
        fps=None,
        use_preloaded_compilation=False,
        mapping_timeout=None,
        allocator_script_filename=None,
        network_groups=None,
    ):
        """
        DFC API for getting Tensorflow graph of a current model. This function is for internal use only

        Args:
            target: One of the hardware targets
                (:class:`~hailo_platform.pyhailort.hw_object.HailoHWObject`) or one of the emulation
                targets (:class:`~hailo_sdk_common.targets.inference_targets.EmulationObject`).
            nodes (dict): Input layer names mapped to last Tensorflow nodes of the pre-processing
                stage. These nodes represent the inputs of the SDK. Layers are asserted to be the HN
                input layers' names. If there is only one input to the graph, nodes also accept a
                Tensor.
            translate_input (bool, optional): Set to True if the input is in a native scale and has
                to be translated to uint8. Usually, True in numeric and hardware targets. Defaults
                to None, which sets it to True for numeric and hardware targets and False for native
                targets.
            rescale_output (bool, optional): Set to True to rescale the results from uint8 to their
                native scale. Typically, True in numeric and hardware targets. Defaults to None, which
                sets it to True for numeric and hardware targets and False for native targets.
            custom_session (:obj:`tf.Session`, optional): Tensorflow session in which the returned
                graph will be loaded. Defaults to None, which means the SDK will create a new
                session.
            twin_mode (bool, optional): When you want to emulate the same model twice inside
                the same Tensorflow graph, set it to True in the second call to avoid tensor name
                conflicts. For instance, this is useful to emulate both native and numeric targets
                together. Defaults to False.
            native_layers (list of str, optional): When using ``SdkMixed`` target, use this param to
                specify the names of all layers to keep native. All other layers will be emulated in
                numeric mode. Defaults to None.
            fps (float, optional): Allocation FPS. If None, the compilation process will automatically
                try to reach max throughput (max FPS). Defaults to None.
            use_preloaded_compilation (bool, optional): Use the HEF loaded to the
                :class:`HailoHWObject <hailo_platform.pyhailort.hw_object.HailoHWObject>`.
                If no HEFs are loaded or this flag is set to False, the HEF will be
                compiled. Relevant if the requested inference target is a
                :class:`HailoHWObject <hailo_platform.pyhailort.hw_object.HailoHWObject>`.
                Defaults to False.
            mapping_timeout (int, optional): Compilation timeout for the whole run. By default,
                the timeout is calculated dynamically based on the model size.
            allocator_script_filename (str, optional): Model script allowing fine-tuning of
                allocation. If present, the Allocator parses it command-by-command and executes.
            network_groups (list, optional): A list of network groups received from
                :func:`~hailo_platform.pyhailort.hw_object.HailoHWObject.configure`.

        Returns:
            :class:`HailoGraphExport <hailo_sdk_common.export.hailo_graph_export.HailoGraphExport>`:
            An object that holds the new tensors that have been added to the graph,
            and serialized HEF in case the target is hardware and the HEF is not previously loaded.

        """
        return self._get_hef_tf_graph(
            target,
            nodes,
            translate_input,
            rescale_output,
            custom_session,
            twin_mode,
            native_layers,
            fps,
            use_preloaded_compilation,
            mapping_timeout,
            allocator_script_filename,
            network_groups,
        )

    def compile(self):
        """
        DFC API that compiles the currently loaded model for Hailo hardware.

        Returns:
            bytes: The HEF data holding the hardware representation of this model.

        Example:
            >>> runner = ClientRunner(har="my_model.har")
            >>> compiled_model = runner.compile()

        """
        hef_data = self._compile()
        return hef_data

    @contextmanager
    def _hef_infer_context(self, hailo_export):
        if hailo_export.hef_infer_wrapper is None:
            yield hailo_export
        else:
            if not self._hailo_platform:
                raise HailoPlatformMissingException(
                    "The HailoRT Python API (hailo_platform package) must be installed "
                    "in order to use the Hailo hardware",
                )
            infer_wrapper = hailo_export.hef_infer_wrapper
            output_format_type = (
                self._hailo_platform.FormatType.FLOAT32
                if infer_wrapper.rescale_output
                else self._hailo_platform.FormatType.AUTO
            )
            input_format_type = (
                self._hailo_platform.FormatType.FLOAT32
                if infer_wrapper.translate_input
                else self._hailo_platform.FormatType.AUTO
            )
            network = hailo_export.network_groups[0]
            input_vstreams_params = self._hailo_platform.InputVStreamParams.make_from_network_group(
                configured_network=network,
                format_type=input_format_type,
            )
            output_vstreams_params = self._hailo_platform.OutputVStreamParams.make_from_network_group(
                configured_network=network,
                format_type=output_format_type,
            )
            application_params = network.create_params()
            with network.activate(application_params), self._hailo_platform.InferVStreams(
                network,
                input_vstreams_params,
                output_vstreams_params,
                tf_nms_format=True,
            ) as infer_pipeline:
                infer_wrapper.infer_pipeline = infer_pipeline
                yield hailo_export

    @contextmanager
    def infer_context(
        self,
        inference_context: InferenceContext,
        device_ids=None,
        nms_score_threshold=None,
        gpu_policy: DistributionStrategy = DistributionStrategy.AUTO,
        custom_infer_config: FlowCommands = None,
        lora_adapter_name: Optional[str] = None,
    ) -> Generator[ContextInfo, None, None]:
        """DFC API for generating context for inference.
        The context must be used with the `infer` API.

        The yielded context is marked as closed (``open = False``) when the ``with`` block is left,
        also when the block is left because of an exception.

        Args:
            inference_context (:class:`~hailo_sdk_client.exposed_definitions.InferenceContext`): Enum to control which
                inference types to use.
            device_ids (list of str, optional): device IDs to create VDevice from, call :func:`Device.scan` to get
                a list of all available devices. Excludes 'params'.
            nms_score_threshold (float, optional): score threshold filtering for on device nms.
                Relevant only when nms is used.
            gpu_policy (DistributionStrategy or its value, optional): Sets the gpu policy for emulation based
                inference, AUTO will distribute the inference
                across available GPUS using a Mirrored Strategy (Parallel DATA).
                With ``InferenceContext.SDK_HAILO_HW`` only AUTO and SINGLE are accepted.
            custom_infer_config: debugging capabilities for pinpointing sources of noise in the optimization process. See Flow State Handler.
                Used by the emulation contexts only.
            lora_adapter_name (str, optional): optional argument to specify the lora adapter name for inference.

        Raises:
            HailoPlatformMissingException: In case, HW inference is requested but HailoRT is not installed.
            InvalidArgumentsException: In case, InferenceContext is not recognized, or the gpu policy is not
                supported with ``InferenceContext.SDK_HAILO_HW``.

        Example:
            >>> with runner.infer_context(InferenceContext.SDK_NATIVE) as ctx:
            >>>     result = runner.infer(
            ...         ctx,
            ...         dataset=tf.data.Dataset.from_tensor_slices(np.ones((1, 10))),
            ...         batch_size=1
            ...     )

        """
        if nms_score_threshold is not None:
            # Only report whether the threshold is going to have an effect; it is applied below
            context_to_model = {
                InferenceContext.SDK_NATIVE: self._native_model,
                InferenceContext.SDK_FP_OPTIMIZED: self._fp_model,
            }
            hn_model = context_to_model.get(inference_context, self._hn)
            if hn_model.has_postprocess_layer():
                self._logger.info(f"Setting NMS score threshold to {nms_score_threshold}")
            else:
                self._logger.info(
                    "CPU postprocess does not exist in the model, ignoring the given argument `nms_score_threshold`",
                )

        if inference_context in [
            InferenceContext.SDK_NATIVE,
            InferenceContext.SDK_FP_OPTIMIZED,
            InferenceContext.SDK_QUANTIZED,
            InferenceContext.SDK_BIT_EXACT,
        ]:
            context_info = InternalContextInfo(
                infer_context=inference_context,
                open=True,
                gpu_policy=DistributionStrategy(gpu_policy),
                flow_commands=custom_infer_config,
                lora_adapter_name=lora_adapter_name,
            )

            try:
                with self._sdk_backend.override_nms_score_threshold(nms_score_threshold):
                    yield context_info
            finally:
                # Close the context even when the caller's block raised
                context_info.open = False

        elif inference_context == InferenceContext.SDK_HAILO_HW:
            # Normalize once, so a policy given by value works for the check, for the error
            # message and for the context info alike
            hw_gpu_policy = DistributionStrategy(gpu_policy)
            if hw_gpu_policy not in {DistributionStrategy.AUTO, DistributionStrategy.SINGLE}:
                raise InvalidArgumentsException(
                    f"Gpu Policy: {hw_gpu_policy.name} is not supported with InferenceContext {inference_context.name}",
                )

            if not self._hailo_platform:
                raise HailoPlatformMissingException(
                    "The HailoRT Python API (hailo_platform package) must be installed to use the Hailo hardware",
                )

            params = self._hailo_platform.VDevice.create_params()

            # The HailoRT service feature is enabled or disabled according to self._use_service.
            # Round-robin scheduling is selected when the service is used or on hailo10h.
            params.multi_process_service = self._use_service
            params.scheduling_algorithm = (
                self._hailo_platform.HailoSchedulingAlgorithm.ROUND_ROBIN
                if self._use_service or self.hw_arch == "hailo10h"
                else self._hailo_platform.HailoSchedulingAlgorithm.NONE
            )

            # change this block when async api on hailo8 is enabled SDK-51150
            if self.hw_arch == "hailo10h":
                with self._hailo_platform.VDevice(device_ids=device_ids, params=params) as target:
                    infer_model = target.create_infer_model(self._hef)
                    for output in infer_model.outputs:
                        if output.is_nms and nms_score_threshold is not None:
                            output.set_nms_score_threshold(nms_score_threshold)
                        output.set_format_type(self._hailo_platform.FormatType.FLOAT32)

                    for model_input in infer_model.inputs:
                        model_input.set_format_type(self._hailo_platform.FormatType.FLOAT32)

                    with infer_model.configure() as configured_infer_model:
                        infer_wrapper = HefInferWrapper(
                            infer_model,
                            target,
                            rescale_output=True,
                            translate_input=True,
                            output_names=[layer.name for layer in self._hn.get_real_output_layers(False)],
                            configured_infer_model=configured_infer_model,
                        )
                        output_format = [np.float32] * len(infer_model.output_names)
                        context_info = InternalContextInfo(
                            infer_context=inference_context,
                            open=True,
                            graph_export=GraphExport(None, infer_wrapper, output_format),
                            gpu_policy=hw_gpu_policy,
                            lora_adapter_name=lora_adapter_name,
                        )
                        try:
                            yield context_info
                        finally:
                            context_info.open = False
            else:
                # old sync api
                with self._hailo_platform.VDevice(device_ids=device_ids, params=params) as target:
                    output_format_type = self._hailo_platform.FormatType.FLOAT32
                    input_format_type = self._hailo_platform.FormatType.FLOAT32
                    network_groups = self._get_network_groups_for_inference(target, inference_context)
                    nodes = self._get_keras_input_nodes()
                    output_tensors = [
                        tf.float32 for _ in self._hn.get_output_layers(remove_non_neural_core_layers=False)
                    ]
                    infer_wrapper = self._get_infer_wrapper(nodes=nodes, target=target, network_groups=network_groups)
                    network = network_groups[0]
                    input_vstreams_params = self._hailo_platform.InputVStreamParams.make_from_network_group(
                        configured_network=network, format_type=input_format_type
                    )
                    output_vstreams_params = self._hailo_platform.OutputVStreamParams.make_from_network_group(
                        configured_network=network, format_type=output_format_type
                    )
                    application_params = network.create_params()
                    with network.activate(application_params), self._hailo_platform.InferVStreams(
                        network,
                        input_vstreams_params,
                        output_vstreams_params,
                        tf_nms_format=True,
                    ) as infer_pipeline:
                        if nms_score_threshold is not None:
                            infer_pipeline.set_nms_score_threshold(nms_score_threshold)
                        infer_wrapper.infer_pipeline = infer_pipeline
                        context_info = InternalContextInfo(
                            infer_context=inference_context,
                            open=True,
                            graph_export=GraphExport(network_groups, infer_wrapper, output_tensors),
                            gpu_policy=hw_gpu_policy,
                            lora_adapter_name=lora_adapter_name,
                        )
                        try:
                            yield context_info
                        finally:
                            context_info.open = False
        else:
            raise InvalidArgumentsException(f"inference_context {inference_context} is not supported.")

    @allowed_states(*HN_STATES)
    def _compile(self, fps=None, mapping_timeout=None, allocator_script_filename=None):
        """
        Compile the current model through the SDK backend and store the resulting HEF.

        Args:
            fps (optional): Passed to the backend compilation.
            mapping_timeout (optional): Passed to the backend compilation.
            allocator_script_filename (optional): When given, this model script is loaded before
                compiling; a warning is logged if a model script was already loaded.

        Returns:
            bytes: The serialized HEF, which is also kept in ``self._hef``.

        """
        # The state is sampled up front: it decides which compiled state the runner ends in
        state_before_compile = self.state

        if allocator_script_filename:
            if self.model_script:
                override_notice = (
                    f"Taking model script commands from {allocator_script_filename} and ignoring "
                    "previous allocation script commands"
                )
                self._logger.warning(override_notice)
            self.load_model_script(allocator_script_filename)

        hef_buffer = self._sdk_backend.compile(fps, self.model_script, mapping_timeout)
        self._auto_model_script = self._sdk_backend.get_auto_alls()

        if state_before_compile in SLIM_STATES:
            self._state = States.COMPILED_SLIM_MODEL
        else:
            self._state = States.COMPILED_MODEL

        self._hef = bytes(hef_buffer)
        return self._hef

    @allowed_states(States.UNINITIALIZED, States.ORIGINAL_MODEL, States.HAILO_MODEL)
    def translate_onnx_model(
        self,
        model=None,
        net_name="model",
        start_node_names=None,
        end_node_names=None,
        net_input_shapes=None,
        augmented_path=None,
        disable_shape_inference=False,
        disable_rt_metadata_extraction=False,
        net_input_format=None,
        **kwargs,
    ):
        """
        DFC API for parsing an ONNX model. After the call the runner holds the loaded HN (model)
        and its parameters.

        Args:
            model (str or bytes or pathlib.Path): The ONNX model file to parse, given as a path or
                as bytes.
            net_name (str): Name of the new HN to generate.
            start_node_names (list of str, optional): ONNX nodes from which the parsing starts.
            end_node_names (list of str, optional): ONNX nodes after which the parsing can stop,
                once all of them are parsed.
            net_input_shapes (dict or list, optional): Input shapes for the start nodes given in
                start_node_names: a dictionary whose keys are the start node names and whose values
                are the corresponding input shapes.
                Use only when the original model has dynamic input shapes (described with a
                wildcard denoting each dynamic axis, e.g. [b, c, h, w]).
                For a single input network a list (e.g. [b, c, h, w]) can be given instead.
            augmented_path: Path to save a modified model, augmented with tensors names (where
                applicable).
            disable_shape_inference: When set to True, shape inference with ONNX runtime is
                disabled.
            disable_rt_metadata_extraction: When set to True, runtime metadata extraction is
                disabled. In this case, generating a model using
                :func:`~hailo_sdk_client.runner.client_runner.ClientRunner.get_hailo_runtime_model`
                won't be supported.
            net_input_format (dict of str to list of :class:`~hailo_sdk_client.exposed_definitions.Dims`, optional):
                Input format for the start nodes given in start_node_names: a dictionary whose keys
                are the start node names and whose values are the corresponding input format (list
                of :class:`~hailo_sdk_client.exposed_definitions.Dims`).
                The defaults are as follows:
                - rank 2 input: [Dims.BATCH, Dims.CHANNELS]
                - rank 3 input: [Dims.BATCH, Dims.WIDTH, Dims.CHANNELS]
                - rank 4 input: [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]
                - rank 5 input: [Dims.BATCH, Dims.CHANNELS, Dims.DISPARITY, Dims.HEIGHT, Dims.WIDTH]
                usage example: net_input_format={'Conv_3': [Dims.BATCH, Dims.CHANNELS, Dims.HEIGHT, Dims.WIDTH]}
            **kwargs: Additional keyword arguments, forwarded to the parser as they are.

        Note:
            A non-default ``start_node_names`` requires the model to be shape inference compatible:
            either it has a real input shape, or, for a dynamic input shape, the
            ``net_input_shapes`` field is provided to specify the input shapes of the given start
            nodes. The order of the output nodes is determined by the order of the end_node_names.

        Returns:
            tuple: The first item is the HN JSON as a string. The second item is the params dict.

        """
        onnx_parser = Parser()
        parser_arguments = {
            "model": model,
            "net_name": net_name,
            "start_node_names": start_node_names,
            "end_node_names": end_node_names,
            "net_input_shapes": net_input_shapes,
            "augmented_path": augmented_path,
            "disable_shape_inference": disable_shape_inference,
            "disable_rt_metadata_extraction": disable_rt_metadata_extraction,
            "net_input_format": net_input_format,
        }
        onnx_parser.translate_onnx_model(**parser_arguments, **kwargs)
        parsed_data = onnx_parser.return_data
        return self._finalize_parsing(parsed_data)

    @allowed_states(States.UNINITIALIZED, States.ORIGINAL_MODEL, States.HAILO_MODEL)
    def translate_tf_model(
        self,
        model_path=None,
        net_name="model",
        start_node_names=None,
        end_node_names=None,
        tensor_shapes=None,
    ):
        """
        DFC API for parsing a TF model given by a checkpoint/pb/savedmodel/tflite file. After the
        call the runner holds the loaded HN (model) and its parameters.

        Args:
            model_path (str): Path of the file to parse.
                Possible formats (recommend to move to TFLite, see user guide for more details):
                * SavedModel (TF2): [Deprecated] Saved model export from Keras, file named saved_model.pb|pbtxt from the model dir.
                * TFLite: Tensorflow lite model, converted from ckpt/frozen/Keras to file with .tflite suffix.
            net_name (str): Name of the new HN to generate.
            start_node_names (list of str, optional): TensorFlow nodes from which the parsing
                starts.
            end_node_names (list of str, optional): Tensorflow nodes after which the parsing can
                stop, once all of them are parsed.
            tensor_shapes (dict, optional): [Deprecated] Names of tensors mapped to the shapes to
                set in the TensorFlow graph. Use only for placeholders with a wildcard shape.

        Note:
            * The order of the output nodes is determined by the order of the end_node_names.
            * TF1.x (.ckpt/.pb) and TF2.x (.pb) models support were deprecated, it is recommended to use TFLite models (see user guide for more details).

        Returns:
            tuple: The first item is the HN JSON, as a string. The second item is the params dict.

        Example:
            >>> model = keras.Sequential([
            ...    layers.Conv2D(32, 3, activation="relu"),
            ...    layers.Conv2D(64, 3, activation="relu"),
            ...    layers.MaxPooling2D(3)])
            >>> model.predict(random.uniform(shape=(1, 32, 32, 3), minval=-1, maxval=1))
            >>> converter = tf.lite.TFLiteConverter.from_keras_model(model)
            >>> tflite_model = converter.convert()
            >>> with tf.io.gfile.GFile('my_model.tflite', "wb") as f:
            ...    f.write(tflite_model)
            >>> runner = ClientRunner(hw_arch='hailo8')
            >>> hn, params = runner.translate_tf_model(
            ...     'my_model.tflite', 'MyCoolModel', ['sequential/Conv1'], ['sequential/Maxpool'])

        """
        tf_parser = Parser()
        tf_parser.translate_tf_model(
            model_path=model_path,
            net_name=net_name,
            start_node_names=start_node_names,
            end_node_names=end_node_names,
            tensor_shapes=tensor_shapes,
        )
        parsed_data = tf_parser.return_data
        return self._finalize_parsing(parsed_data)

    def _finalize_parsing(self, return_data):
        self._preprocess_model = return_data.get("preprocess_model")
        self._postprocess_model = return_data.get("postprocess_model")
        self.original_model_meta = return_data["original_model_meta"]
        self._original_model_path = return_data["original_model_meta"].get("original_model_path")

        # generate new runner state
        self.set_hn(return_data["hn_data"])
        self.load_params(return_data["bn_rescaled_params"])
        native_npz = {**self.get_params()}

        detected_anchors = self.original_model_meta.get("detected_anchors")
        if detected_anchors:
            meta_arch = detected_anchors["meta_arch"]
            full_config = self._get_nms_full_config(meta_arch)
            engine = NMSPostprocessCommand.get_default_engine_from_meta_arch(meta_arch)
            self._sdk_backend.nms_metadata = NMSMetaData(
                NMSConfig.from_json(full_config, meta_arch),
                meta_arch,
                engine,
                full_config,
            )

        self._transpose_h_w_warning()
        return return_data["hn_data"], native_npz

    def _transpose_h_w_warning(self):
        msg = ""
        hailo_nn = self.get_hn_model()
        start_layers = hailo_nn.get_real_input_layers()
        msg += self._find_transpose_h_w(start_layers, start_nodes=True)
        end_layers = hailo_nn.get_real_output_layers()
        msg += self._find_transpose_h_w(end_layers, start_nodes=False)
        if msg:
            self._logger.warning(msg)

    def _find_transpose_h_w(self, layers, start_nodes=True):
        format_conversions = [x for x in layers if x.op == LayerType.format_conversion]
        trans_hw = [x for x in format_conversions if x.conversion_type == FormatConversionType.transpose_height_width]
        if trans_hw:
            location = "beginning" if start_nodes else "end"
            return (
                f"Found inefficient layer(s) at the {location} of the model: "
                f"{', '.join([x.full_name_msg for x in trans_hw])}. Please consider offloading it to the host, "
                "by cutting it out of the model using the start/end node names arguments."
            )
        return ""

    @allowed_states(States.HAILO_MODEL, States.FP_OPTIMIZED_MODEL, States.QUANTIZED_MODEL, States.QUANTIZED_SLIM_MODEL)
    def join(self, runner, scope1_name=None, scope2_name=None, join_action=JoinAction.NONE, join_action_info=None):
        """
        DFC API to join two models, so they will be compiled together.

        Args:
            runner (:class:`~hailo_sdk_client.runner.client_runner.ClientRunner`): The client runner
                to join to this one.
            scope1_name (dict or str, optional): In case dict is given, mapping between existing scope names to new
                scope names for the layers of this model (see example below).
                In case str is given, the scope name will be used for all layers of this model. A string can be used
                only when there is a single scope name.
            scope2_name (dict or str, optional): Same as `scope1_name` for the runner to join.

        Example:
                >>> net1_scope_names = {'net1_scope1': 'net_scope1',
                ...                     'net1_scope2': 'net_scope2'}
                >>> net2_scope_names = {'net2': 'net_scope3'}
                >>> runner1.join(runner2, scope1_name=net1_scope_names,
                ...              scope2_name=net2_scope_names)
            join_action (:class:`~hailo_sdk_client.exposed_definitions.JoinAction`, optional): Type
                of action to run in addition to joining the models:

                * :attr:`~hailo_sdk_client.exposed_definitions.JoinAction.NONE`: Join the graphs
                  without any connection between them.
                * :attr:`~hailo_sdk_client.exposed_definitions.JoinAction.AUTO_JOIN_INPUTS`: Automatically
                  detect inputs for both graphs and combines them into one. This only works when both
                  networks have a single input of the same shape.
                * :attr:`~hailo_sdk_client.exposed_definitions.JoinAction.AUTO_CHAIN_NETWORKS`: Automatically
                  detect the output of this model and the input of the other model, and connect them. Only
                  works when this model has a single output, and the other model has a single input,
                  of the same shape.
                * :attr:`~hailo_sdk_client.exposed_definitions.JoinAction.CUSTOM`: Supply a custom
                  dictionary ``join_action_info``, which specifies which nodes from this model need
                  to be connected to which of the nodes in the other graph. If keys and values are
                  inputs, the inputs are joined. If keys are outputs, and values are inputs, the
                  networks are chained as described in the dictionary.

            join_action_info (dict, optional): Join information to be given when ``join_action`` is
                :attr:`~hailo_sdk_client.exposed_definitions.JoinAction.NONE`, as explained above.

        Example:
            >>> info = {'net1/output_layer1': 'net2/input_layer2',
            ...         'net1/output_layer2': 'net2/input_layer1'}
            >>> runner1.join(runner2, join_action=JoinAction.CUSTOM, join_action_info=info)

        """
        if runner._state != self._state:
            raise UnsupportedRunnerJoinException(
                "Joining runners is only allowed when both contain native graphs, or both contain quantized graphs",
            )
        if self._state in QUANTIZED_STATES and join_action != JoinAction.NONE:
            raise UnsupportedRunnerJoinException(
                "Merging inputs or chaining networks is not supported for quantized models",
            )
        if join_action == JoinAction.CUSTOM and not join_action_info:
            raise UnsupportedRunnerJoinException(
                "join_action is JoinAction.CUSTOM, but join_action_info was not specified",
            )

        scope_names1 = get_runner_scopes(scope1_name, self._hn.net_params.net_scopes)
        scope_names2 = get_runner_scopes(scope2_name, runner._hn.net_params.net_scopes)
        scopes_intersection = set(scope_names1.values()).intersection(set(scope_names2.values()))
        if len(scopes_intersection) > 0:
            raise UnsupportedRunnerJoinException(
                f"The scopes for the two networks intersect in these scopes {scopes_intersection}, and"
                " cannot be joined",
            )

        model_script1 = self.model_script if self.model_script else ""
        model_script2 = runner.model_script if runner.model_script else ""
        model_script1, model_script2 = handle_model_script_join(
            self.get_hn_model(),
            model_script1,
            scope_names1,
            runner.get_hn_model(),
            model_script2,
            scope_names2,
            self._state,
            join_action,
        )

        params_translated = self._sdk_backend.get_params_translated()
        params_fp_optimized = self._sdk_backend.get_params_fp_optimized()
        params = self._sdk_backend.get_params()
        net1_q_params = ModelParams(params_translated) if self._state in QUANTIZED_STATES else None
        net1_fp_opt_params = ModelParams(params_fp_optimized) if (params_fp_optimized is not None) else None
        net1_params = ModelParams(params) if params else None
        runner_params_translated = runner._sdk_backend.get_params_translated()
        runner_params_fp_optimized = runner._sdk_backend.get_params_fp_optimized()
        runner_params = runner._sdk_backend.get_params()
        net2_q_params = ModelParams(runner_params_translated) if self._state in QUANTIZED_STATES else None
        net2_fp_opt_params = ModelParams(runner_params_fp_optimized) if runner_params_fp_optimized else None
        net2_params = ModelParams(runner_params) if runner_params else None

        new_hn = merge_hns(self._hn, runner._hn, scope_names1, scope_names2, join_action, join_action_info)
        self._sdk_backend = self._generate(new_hn)
        self._model_name = new_hn.name

        if net1_params and net2_params:
            self.load_params(
                merge_params(net1_params, net2_params, ParamsKinds.NATIVE, scope_names1, scope_names2),
                params_kind=ParamsKinds.NATIVE,
            )
        if net1_q_params and net2_q_params:
            self.load_params(
                merge_params(net1_q_params, net2_q_params, ParamsKinds.TRANSLATED, scope_names1, scope_names2),
                params_kind=ParamsKinds.TRANSLATED,
            )

        if net1_fp_opt_params and net2_fp_opt_params:
            self.load_params(
                merge_params(
                    net1_fp_opt_params,
                    net2_fp_opt_params,
                    ParamsKinds.FP_OPTIMIZED,
                    scope_names1,
                    scope_names2,
                ),
                params_kind=ParamsKinds.FP_OPTIMIZED,
            )

        self._sdk_backend.load_model_script(model_script1 + model_script2)

        if self.original_model_meta is not None:
            if "errors" not in self.original_model_meta:
                self.original_model_meta["errors"] = []
            self.original_model_meta["errors"].append("The network was changed due to join action")

    @allowed_states(*HN_STATES)
    def profile(
        self,
        should_use_logical_layers=True,
        hef_filename=None,
        runtime_data=None,
        stream_fps=None,
    ):
        """
        DFC API of the Profiler.

        The HEF handed to the profiler is the content of ``hef_filename`` when it is given, otherwise
        the HEF cached on the runner (which may be None).

        Args:
            should_use_logical_layers (bool, optional): Indicates whether the Profiler should
                combine all physical layers into their original logical layer in the report.
                Defaults to True.
            hef_filename (str, optional): HEF file path. If given, the HEF file is used. If not
                given and the HEF from the previous compilation is cached, the cached HEF is used;
                Otherwise, the automatic mapping tool is used. Use
                :func:`~hailo_sdk_client.runner.client_runner.ClientRunner.compile`
                to generate and set the HEF. Only in post-placement mode. Defaults to None.
            runtime_data (str, optional): runtime_data.json file path produced by hailortcli run2 measure-fw-actions.
            stream_fps (float, optional): FPS used for power and bandwidth calculation.

        Returns:
            tuple: The first item is a JSON with the profiling result summary. The second item is a
            CSV table with detailed profiling information about all model layers. The third item is
            the latency data. Fourth is accuracy data.

        Raises:
            InvalidArgumentsException: If no HEF is available and ``should_use_logical_layers`` is
                False, or ``runtime_data`` / ``stream_fps`` is given.

        Example:
            >>> runner = ClientRunner(har="my_model.har")
            >>> export = runner.profile()

        """
        self._logger.info(f"Running profile for {self.model_name} in state {self.state.value}")

        if hef_filename is None:
            hef = self._hef
        else:
            with open(hef_filename, "rb") as hef_file:
                hef = hef_file.read()

        if hef is None:
            # These options are only accepted together with a HEF; collect every offending one
            # so the error names them all at once.
            rejected = [
                arg_name
                for arg_name, was_given in (
                    ("should_use_logical_layers", not should_use_logical_layers),
                    ("runtime_data", runtime_data is not None),
                    ("stream_fps", stream_fps is not None),
                )
                if was_given
            ]
            if rejected:
                raise InvalidArgumentsException(
                    f"{', '.join(rejected)} can not be given when runner state is {self.state.value}",
                )

        return self._sdk_backend.profile(
            should_use_logical_layers=should_use_logical_layers,
            hef=hef,
            runtime_data=runtime_data,
            stream_fps=stream_fps,
        )

    @allowed_states(States.COMPILED_MODEL, States.COMPILED_SLIM_MODEL)
    def save_autogen_allocation_script(self, path):
        """
        DFC API for retrieving listed operations of the last allocation in .alls format.

        Args:
            path (str): Path where the script is saved.

        Returns:
            bool: True if the auto-generated script was written to ``path``; False if the runner
            holds no auto-generated script (nothing is written in that case).

        """
        script = self._auto_model_script
        if not script:
            self._logger.debug("Auto-generated model script was not written in this session.")
            return False

        with open(path, "w") as script_file:
            script_file.write(script)
        return True

    @staticmethod
    def _save_params(params, path):
        """Write ``params`` to the binary file at ``path`` using the ``.npz`` serialization type."""
        with open(path, "wb") as params_file:
            # TODO: SDK-10099
            # Hand the serializer a plain dict built from the params' items.
            plain_params = {name: value for name, value in params.items()}
            save_params(params_file, plain_params, type_=".npz")

    @property
    def model_name(self):
        """Get the current model (network) name, as stored on the runner."""
        return self._model_name

    @property
    def model_optimization_commands(self):
        """Get the model optimization commands exposed by the SDK backend."""
        return self._sdk_backend.model_optimization_commands

    @property
    def hw_arch(self):
        """Get the hardware architecture (``hw_arch``) configured on this runner."""
        return self._hw_arch

    @property
    def state(self):
        """Get the current model state, as stored on the runner."""
        return self._state

    @property
    def hef(self):
        """Get the latest HEF compilation held by the runner."""
        return self._hef

    @hef.setter
    def hef(self, hef):
        """
        Set the runner's HEF after validating it.

        Args:
            hef: HEF to store. It is passed to the platform's ``HEF`` constructor as a validity check.

        Raises:
            InvalidArgumentsException: If constructing the platform ``HEF`` object from ``hef``
                fails. The underlying error is chained as ``__cause__``.

        """
        try:
            # The constructed object is discarded; only a successful construction matters here.
            self._hailo_platform.HEF(hef)
        except Exception as e:
            # Chain the original error so its traceback is not lost behind the wrapper exception.
            raise InvalidArgumentsException(f"Got invalid HEF file in setter {e}") from e
        self._hef = hef

    @property
    def _hn_dict(self):
        """HN of the current model as a dictionary; raises HNNotSetException when no HN is set."""
        current_hn = self._hn
        if current_hn is not None:
            return current_hn.to_hn(current_hn.name, json_dump=False)
        raise HNNotSetException("The HN has not been set yet for current runner")

    @property
    def _native_hn_dict(self):
        """Native HN as a dictionary; raises HNNotSetException when no native model is available."""
        native = self._native_model
        if native is not None:
            return native.to_hn(native.name, json_dump=False)
        raise HNNotSetException("The native HN has not been set yet for current runner")

    @property
    def _fp_hn_dict(self):
        """Full-precision HN as a dictionary; raises HNNotSetException when no such model is available."""
        full_precision = self._fp_model
        if full_precision is not None:
            return full_precision.to_hn(full_precision.name, json_dump=False)
        raise HNNotSetException("The full-precision HN has not been set yet for current runner")

    @property
    def nms_config_file(self):
        return self._sdk_backend.nms_metadata.config_file if self._sdk_backend.nms_metadata else None

    @property
    def nms_engine(self):
        return self._sdk_backend.nms_metadata.engine if self._sdk_backend.nms_metadata else None

    @property
    def nms_meta_arch(self):
        return self._sdk_backend.nms_metadata.meta_arch if self._sdk_backend.nms_metadata else None

    @staticmethod
    def _get_params(keys, params):
        """
        Wrap a subset of ``params`` in a new ``ModelParams``.

        Args:
            keys: Names to keep; ``None`` keeps every entry of ``params``.
            params: Mapping of params to filter, or ``None``.

        Returns:
            ``None`` when ``params`` is ``None``; otherwise a ``ModelParams`` holding the entries of
            ``params`` whose key is contained in ``keys``.

        """
        if params is None:
            return None

        wanted = params.keys() if keys is None else keys
        selected = {}
        for name, value in params.items():
            if name in wanted:
                selected[name] = value
        return ModelParams(selected)

    @allowed_states(*HN_STATES)
    def get_params(self, keys=None):
        """
        Get the native (non-quantized) params the runner uses.

        Args:
            keys (list of str, optional): List of params to retrieve. If not specified, all params
                are retrieved.

        Returns:
            The selected params wrapped in ``ModelParams``, or None when the backend holds no native params.

        """
        native_params = self._sdk_backend.get_params()
        return self._get_params(keys, native_params)

    @allowed_states(*HN_STATES)
    def get_params_translated(self, keys=None):
        """
        Get the quantized params the SDK uses.

        Args:
            keys (list of str, optional): List of params to retrieve. If not specified, all params
                are retrieved.

        Returns:
            The selected params wrapped in ``ModelParams``, or None when the backend holds no translated params.

        """
        translated_params = self._sdk_backend.get_params_translated()
        return self._get_params(keys, translated_params)

    @allowed_states(*HN_STATES)
    def get_params_fp_optimized(self, keys=None):
        """
        Get the fp optimized params.

        Args:
            keys (list of str, optional): List of params to retrieve. If not specified, all params
                are retrieved.

        Returns:
            The selected params wrapped in ``ModelParams``, or None when the backend holds no
            fp optimized params.

        """
        fp_optimized_params = self._sdk_backend.get_params_fp_optimized()
        return self._get_params(keys, fp_optimized_params)

    @allowed_states(*HN_STATES)
    def get_params_statistics(self, keys=None):
        """
        Get the optimization statistics.

        During the optimization stage, statistics about the model and the optimization algorithms are
        gathered. This method returns this information in a ModelParams structure.

        Args:
            keys (list of str, optional): List of params to retrieve. If not specified, all params
                are retrieved.

        Returns:
            The selected statistics wrapped in ``ModelParams``, or None when the backend holds no statistics.

        """
        statistics = self._sdk_backend.get_params_statistics()
        return self._get_params(keys, statistics)

    def _hn_api_deprecation_warning(self):
        """Emit the deprecation notice shared by the HN get/set/save APIs."""
        # TODO: https://hailotech.atlassian.net/browse/SDK-46364
        message = (
            "All ClientRunner APIs for getting/setting/saving the HN model will be deprecated in the near future. "
            "Please use Hailo archive APIs for model inspection and visualization."
        )
        self._logger.deprecation_warning(message, DeprecationVersion.APR2024)

    @allowed_states(*HN_STATES)
    def get_hn_str(self):
        """
        Get the HN JSON after serialization to a formatted string.

        A deprecation warning is logged on every call.

        Raises:
            HNNotSetException: If no HN has been set on the runner.

        """
        self._hn_api_deprecation_warning()
        current_hn = self._hn
        if current_hn is not None:
            return current_hn.to_hn(current_hn.name)
        # msg: Raised when tried to get the HN from the runner before it is set.
        raise HNNotSetException("The HN has not been set yet for current runner")

    @allowed_states(*HN_STATES)
    def get_hn_dict(self):
        """
        Get the HN of the current model as a dictionary.

        Raises:
            HNNotSetException: If no HN has been set on the runner.

        """
        return self._hn_dict

    @allowed_states(*HN_STATES)
    def get_hn(self):
        """
        Get the HN of the current model as a dictionary.

        Alias of ``get_hn_dict``.
        """
        # exists for backward compatibility
        return self.get_hn_dict()

    @allowed_states(*HN_STATES)
    def get_hn_model(self):
        """
        Get the :class:`~hailo_sdk_common.hailo_nn.hailo_nn.HailoNN` object of the current model.

        The object is built from the dictionary form of the current HN on each call.
        """
        hn_as_dict = self.get_hn_dict()
        return HailoNN.from_parsed_hn(hn_as_dict)

    @allowed_states(*HN_STATES)
    def get_native_hn_str(self):
        """
        Get the native HN JSON after serialization to a formatted string.

        Raises:
            HNNotSetException: If the SDK backend holds no native model.

        """
        native = self._sdk_backend.native_model
        if native is not None:
            return native.to_hn(native.name)
        # msg: Raised when tried to get the HN from the runner before it is set.
        raise HNNotSetException("The HN has not been set yet for current runner")

    @allowed_states(*HN_STATES)
    def get_native_hn_dict(self):
        """
        Get the native HN of the current model as a dictionary.

        Raises:
            HNNotSetException: If no native model is available on the runner.

        """
        return self._native_hn_dict

    @allowed_states(*HN_STATES)
    def get_native_hn(self):
        """
        Get the native HN of the current model as a dictionary.

        Alias of ``get_native_hn_dict``.
        """
        # exists for backward compatibility
        return self.get_native_hn_dict()

    @allowed_states(*HN_STATES)
    def get_native_hn_model(self):
        """
        Get the native :class:`~hailo_sdk_common.hailo_nn.hailo_nn.HailoNN` object of the current model.

        The object is built from the dictionary form of the native HN on each call.
        """
        native_as_dict = self.get_native_hn_dict()
        return HailoNN.from_parsed_hn(native_as_dict)

    @allowed_states(
        States.FP_OPTIMIZED_MODEL,
        States.QUANTIZED_BASE_MODEL,
        States.QUANTIZED_MODEL,
        States.COMPILED_MODEL,
    )
    def get_fp_hn_str(self):
        """
        Get the full-precision HN JSON after serialization to a formatted string.

        Raises:
            HNNotSetException: If the SDK backend holds no full-precision model.

        """
        full_precision = self._sdk_backend.fp_model
        if full_precision is not None:
            return full_precision.to_hn(full_precision.name)
        raise HNNotSetException("The full-precision HN has not been set yet for current runner")

    @allowed_states(
        States.FP_OPTIMIZED_MODEL, States.QUANTIZED_BASE_MODEL, States.QUANTIZED_MODEL, States.COMPILED_MODEL
    )
    def get_fp_hn_dict(self):
        """
        Get the full-precision HN of the current model as a dictionary.

        Raises:
            HNNotSetException: If no full-precision model is available on the runner.

        """
        return self._fp_hn_dict

    @allowed_states(
        States.FP_OPTIMIZED_MODEL,
        States.QUANTIZED_BASE_MODEL,
        States.QUANTIZED_MODEL,
        States.COMPILED_MODEL,
    )
    def get_fp_hn_model(self):
        """
        Get the full-precision :class:`~hailo_sdk_common.hailo_nn.hailo_nn.HailoNN` object of the current model.

        The object is built from the dictionary form of the full-precision HN on each call.
        """
        fp_as_dict = self._fp_hn_dict
        return HailoNN.from_parsed_hn(fp_as_dict)

    @allowed_states(States.UNINITIALIZED, States.ORIGINAL_MODEL, States.HAILO_MODEL, States.FP_OPTIMIZED_MODEL)
    def set_hn(self, hn):
        """
        Set the HN of the current model.

        A new SDK backend is generated from ``hn``, the model name is taken from the HN's ``name``
        entry, and the runner moves to the ``HAILO_MODEL`` state. When the runner has no ``hw_arch``
        yet, the default one is selected and a warning is logged.

        Args:
            hn: Hailo network description (HN), as a file-like object, string, dict or
                :class:`~hailo_sdk_common.hailo_nn.hailo_nn.HailoNN`.

        """
        if not self._hw_arch:
            supported_archs = ", ".join(SUPPORTED_HW_ARCHS)
            self._logger.warning(
                f"hw_arch parameter not given, using the default hw_arch {DEFAULT_HW_ARCH}.\n"
                f"If another device is the target, please run again using one of {supported_archs}",
            )
            self._hw_arch = DEFAULT_HW_ARCH

        self._sdk_backend = self._generate(hn)
        # Read the name back from the freshly generated backend's HN rather than from the input.
        self._model_name = self._hn_dict.get("name")
        self._state = States.HAILO_MODEL

    @staticmethod
    def _load_hn(hn):
        """
        Convert any supported HN representation into a ``HailoNN`` object.

        Args:
            hn: The HN as a string, an open file object holding JSON, a parsed dict, or a ``HailoNN``
                (returned as is).

        Raises:
            TypeError: If ``hn`` is none of the supported types.

        """
        if isinstance(hn, string_types):
            return HailoNN.from_hn(hn)
        if isinstance(hn, file_types):
            parsed = json.load(hn)
            return HailoNN.from_parsed_hn(parsed)
        if isinstance(hn, dict):
            return HailoNN.from_parsed_hn(hn)
        if isinstance(hn, HailoNN):
            return hn
        raise TypeError(f"Bad type for hn: {type(hn).__name__}, hn must be file, dict, str, or unicode")

    @property
    def _native_model(self):
        if self._sdk_backend is not None:
            return self._sdk_backend.native_model
        return None

    @property
    def _fp_model(self):
        if self._sdk_backend is not None:
            return self._sdk_backend.fp_model
        return None

    @allowed_states(*HN_STATES)
    def save_hn(self, path):
        """
        Save the HN of the current model.

        The HN is serialized before the destination is opened, so a failure to produce it
        (for example the HN not being set) does not create or truncate the file at ``path``.

        Args:
            path (str): Path where the hn file is saved.

        Raises:
            HNNotSetException: If no HN has been set on the runner.

        """
        self._hn_api_deprecation_warning()
        # Serialize first: opening with "w" truncates the target immediately.
        hn_str = self.get_hn_str()
        with open(path, "w") as f:
            f.write(hn_str)

    @allowed_states(States.HAILO_MODEL, States.FP_OPTIMIZED_MODEL, States.QUANTIZED_MODEL, States.COMPILED_MODEL)
    def save_native_hn(self, path):
        """
        Save the native HN of the current model.

        The HN is serialized before the destination is opened, so a failure to produce it
        (for example the native model not being set) does not create or truncate the file at ``path``.

        Args:
            path (str): Path where the hn file is saved.

        Raises:
            HNNotSetException: If the SDK backend holds no native model.

        """
        self._hn_api_deprecation_warning()
        # Serialize first: opening with "w" truncates the target immediately.
        native_hn_str = self.get_native_hn_str()
        with open(path, "w") as f:
            f.write(native_hn_str)

    @allowed_states(*INITIALIZED_STATES)
    def save_har(self, har_path, compressed=False, save_original_model=False, compilation_only=False):
        """
        Save the current model serialized as Hailo Archive file.

        Args:
            har_path: Path for the created Hailo archive directory.
            compressed: Indicates whether to compress the archive file. Defaults to False.
            save_original_model: Indicates whether to save the original model (TF/ONNX) in the
                archive file. Defaults to False.
            compilation_only: Indicates whether to save a reduced size har, containing only compilation
                related data. A warning is logged when this is requested in the ``ORIGINAL_MODEL``,
                ``HAILO_MODEL`` or ``FP_OPTIMIZED_MODEL`` states.

        """
        # Backend-derived artifacts; all of them stay None while only the original model is held.
        params_statistics = None
        params_translated = None
        params_fp_opt = None
        params_hailo_opt = None
        params = None
        nms_metadata = None
        mo_flavor = None
        flavor_config = None
        modifications_meta_data = None
        optimization_flow_memento = None
        lora_weights_metadata = None

        if self._state != States.ORIGINAL_MODEL:
            backend = self._sdk_backend
            params_statistics = backend.get_params_statistics()
            params_translated = backend.get_params_translated()
            params_fp_opt = backend.get_params_fp_optimized()
            params_hailo_opt = backend.get_params_hailo_optimized()
            params = backend.get_params()
            nms_metadata = backend.nms_metadata
            mo_flavor = backend.mo_flavor
            flavor_config = backend.flavor_config
            modifications_meta_data = self.modifications_meta_data
            optimization_flow_memento = backend.get_flow_memento()
            lora_weights_metadata = backend.lora_weights_metadata

        # backward compatibility: make sure a full-precision model exists before archiving
        states_expecting_fp_model = (States.FP_OPTIMIZED_MODEL, States.QUANTIZED_MODEL, States.COMPILED_MODEL)
        if self.state in states_expecting_fp_model and not self._fp_model:
            self._sdk_backend.update_fp_model(self._hn)

        states_without_compilation_only = (States.ORIGINAL_MODEL, States.HAILO_MODEL, States.FP_OPTIMIZED_MODEL)
        if compilation_only and self.state in states_without_compilation_only:
            self._logger.warning(f"Compilation only mode is not supported in {self.state} state, ignoring.")

        # The archive constructor is called positionally; the argument order below must be preserved.
        archive = HailoArchive(
            self._state,
            self._original_model_path,
            self._hn,
            self._native_model,
            self._fp_model,
            self._model_name,
            params,
            None,
            params_fp_opt,
            params_hailo_opt,
            params_translated,
            params_statistics,
            self.model_script,
            self._auto_model_script,
            self._force_weightless_model,
            self._hef,
            self._hw_arch,
            self.original_model_meta,
            self._preprocess_model,
            self._postprocess_model,
            nms_metadata,
            mo_flavor,
            flavor_config,
            modifications_meta_data,
            optimization_flow_memento=optimization_flow_memento,
            lora_weights_metadata=lora_weights_metadata,
        )
        archive.save(
            har_path,
            compressed,
            save_original_model,
            compilation_only,
            params_serialization=ParamSerializationType.NPZ,
        )
        self._logger.info(f"Saved HAR to: {os.path.abspath(har_path)}")

    @allowed_states(States.UNINITIALIZED)
    def load_har(self, har=None):
        """
        Set the current model properties using a given Hailo Archive file.

        The runner's state, hardware architecture, backend, params, model script, HEF and
        pre/post-process models are restored from the archive.

        Args:
            har (str or :class:`~hailo_sdk_common.hailo_archive.hailo_archive.HailoArchive`): Path
                to the Hailo Archive file or an initialized HailoArchive object to restore.

        NOTE(review): the default ``har=None`` is not handled; the attribute accesses on ``har``
        below fail for it, so callers must pass a value.

        """
        har_path = None
        if isinstance(har, string_types):
            har_path = har

        if har_path is not None:
            # A path was given: load the archive using a temporary directory owned by the runner
            # (created on first use).
            if self._temp_dir is None:
                self._temp_dir = tempfile.TemporaryDirectory()
            har = HailoArchive.load(har_path, temp_dir=self._temp_dir.name)

        # hw_arch precedence: the runner's own value (warn if the HAR disagrees), then the HAR's
        # value, then the archive's legacy default.
        if self._hw_arch:
            if har.hw_arch and har.hw_arch != self._hw_arch:
                self._logger.warning(
                    f"hw_arch from HAR is {har.hw_arch} but client runner was initialized with "
                    f"{self._hw_arch}. Using {self._hw_arch}",
                )
        elif har.hw_arch:
            self._hw_arch = har.hw_arch
        else:
            self._hw_arch = har.LEGACY_DEFAULT_HW_ARCH

        self._state = har.state

        if self._state in HN_STATES:
            self._model_name = HailoNN.get_valid_input_identifier(har.model_name, "har_model_name")
            hn = self._load_hn(har.hn)
            # When the archive carries no native HN, the regular HN is used in its place.
            native_hn = self._load_hn(har.native_hn) if har.native_hn is not None else hn
            native_hn_dict = native_hn.to_hn(native_hn.name, json_dump=False)
            # The backend is generated from the native HN and then updated with the archived HN.
            self._sdk_backend = self._generate(native_hn_dict)
            self._sdk_backend.har = har
            self._sdk_backend.update_model(hn, override_config=True)
            if har.model_script:
                self._sdk_backend.load_model_script_from_har(har)
            if self._state in [States.QUANTIZED_MODEL, States.COMPILED_MODEL] + SLIM_STATES:
                hn.fill_default_quantization_params(logger=self._logger)

            if self._state in [
                States.FP_OPTIMIZED_MODEL,
                States.QUANTIZED_MODEL,
                States.QUANTIZED_BASE_MODEL,
                States.QUANTIZED_SLIM_MODEL,
            ]:
                self._sdk_backend.reapply_alls_commands_on_load_har(hn, har.params)

            # Copy the remaining archive metadata onto the new backend.
            self._sdk_backend.force_weightless_model = har.force_weightless_model
            self._sdk_backend.nms_metadata = har.nms_metadata
            self._sdk_backend.mo_flavor = har.mo_flavor
            self._sdk_backend.flavor_config = har.flavor_config
            if har.modifications_meta_data:
                self._sdk_backend.modifications_meta_data = har.modifications_meta_data
            if har.lora_weights_metadata:
                self._sdk_backend.lora_weights_metadata = har.lora_weights_metadata

        # Params present in the archive are loaded into the backend, each under its own kind.
        if har.params:
            self._sdk_backend.load_params(har.params, ParamsKinds.NATIVE)
        if har.params_after_bn:
            # No params kind is passed here, so the backend's default kind applies.
            self._sdk_backend.load_params(har.params_after_bn)
        if har.params_statistics:
            self._sdk_backend.load_params(har.params_statistics, ParamsKinds.STATISTICS)

        if self._state in [
            States.FP_OPTIMIZED_MODEL,
            States.QUANTIZED_MODEL,
            States.QUANTIZED_BASE_MODEL,
            States.COMPILED_MODEL,
        ]:
            if har.fp_hn is not None:
                self._sdk_backend.update_fp_model(self._load_hn(har.fp_hn))
            if har.params_fp_opt is not None:
                self._sdk_backend.load_params(har.params_fp_opt, ParamsKinds.FP_OPTIMIZED)

        if self._state in [States.QUANTIZED_BASE_MODEL, States.QUANTIZED_MODEL, States.COMPILED_MODEL] + SLIM_STATES:
            if har.params_hailo_opt is not None:
                self._sdk_backend.load_params(har.params_hailo_opt, ParamsKinds.HAILO_OPTIMIZED)
            self._sdk_backend.load_params(har.params_translated, ParamsKinds.TRANSLATED)

        # load hn and load params are changing the given state
        self._state = har.state

        if self._state in [States.COMPILED_MODEL, States.COMPILED_SLIM_MODEL]:
            self._auto_model_script = har.auto_model_script
            self._hef = har.hef

        self._original_model_path = har.original_model_path

        # The metadata dict is taken from the archive as is; the entries below are then converted
        # in place from their serialized form into objects.
        self.original_model_meta = har.original_model_meta
        if har.original_model_meta is not None:
            if har.original_model_meta.get("detected_anchors") and har.original_model_meta["detected_anchors"].get(
                "meta_arch",
            ):
                self.original_model_meta["detected_anchors"]["meta_arch"] = NMSMetaArchitectures(
                    har.original_model_meta["detected_anchors"]["meta_arch"],
                )
            if har.original_model_meta.get("parsing_report"):
                report = self.original_model_meta["parsing_report"]
                if isinstance(report, str):  # backwards compatibility
                    report = json.loads(report)
                self.original_model_meta["parsing_report"] = ParsingReport(**report)
        self._preprocess_model = har.preprocess_model
        self._postprocess_model = har.postprocess_model

    @allowed_states(States.HAILO_MODEL, States.FP_OPTIMIZED_MODEL, States.QUANTIZED_BASE_MODEL)
    def _optimize(
        self,
        calib_data,
        data_type=None,
        work_dir=None,
        checkpoint: SupportedStops = SupportedStops.NONE,
        memento: Optional[FlowCheckPoint] = None,
    ) -> CheckpointInfo:
        """
        Run the quantization flow, preceded by full-precision optimization when still needed.

        Args:
            calib_data: Calibration data, wrapped together with ``data_type`` in a dataset container.
            data_type (optional): Type of ``calib_data``.
            work_dir (optional): Forwarded to the backend quantization call.
            checkpoint: Stop point stored in the checkpoint info as ``run_until``.
            memento (optional): Flow checkpoint stored in the checkpoint info as ``flow_memento``.

        Returns:
            CheckpointInfo: The checkpoint info returned by the backend quantization call. The runner
            moves to ``QUANTIZED_MODEL`` only when it reports ``quantization_done``.

        """
        calib_container = DatasetContianer(calib_data, data_type)
        checkpoint_info = CheckpointInfo(run_until=checkpoint, flow_memento=memento)

        # Full-precision optimization runs here only if it was not done before.
        if self._state == States.HAILO_MODEL:
            self._sdk_backend.optimize_full_precision(calib_container)
            self._state = States.FP_OPTIMIZED_MODEL

        #! If the container holds no data, optimize_full_precision creates a dataset and saves it in
        #! self._sdk_backend.calibration_data. This can't be done earlier because
        #! optimize_full_precision might change the input shapes !!
        # Fill in the calibration set if data is missing
        calib_container = self._sdk_backend.update_or_create_calib_data(calib_container)

        if self._state != States.QUANTIZED_BASE_MODEL:
            checkpoint_info = self._sdk_backend.full_quantization(
                calib_container,
                work_dir=work_dir,
                checkpoint_info=checkpoint_info,
            )
        else:
            # Base model already quantized: quantize the last LoRA adapter in the backend's list.
            last_adapter = self._sdk_backend.get_lora_adapters()[-1]
            checkpoint_info = self._sdk_backend.lora_quantization(
                last_adapter,
                calib_container,
                work_dir=work_dir,
                checkpoint_info=checkpoint_info,
            )

        if checkpoint_info.quantization_done:
            self._state = States.QUANTIZED_MODEL
        return checkpoint_info

    def _get_batch_count(self, image_count, batch_size, calib_num_batch):
        if (image_count is not None) and calib_num_batch > (image_count // batch_size):
            new_batch_count = image_count // batch_size
            self._logger.warning(
                f"Dataset didn't have enough images for {calib_num_batch} batches with a batch size "
                f"of {batch_size}. Quantizing using {new_batch_count} batches",
            )
            return new_batch_count
        return calib_num_batch

    @allowed_states(*HN_STATES)
    def model_summary(self):
        """Print a summary of the model layers, delegating to the ``summary`` method of the HN model."""
        hn_model = self.get_hn_model()
        hn_model.summary()

    @allowed_states(States.HAILO_MODEL)
    def optimize_full_precision(self, calib_data=None, data_type=None):
        """
        Apply model optimizations to the model, keeping full-precision:
            1. Fusing various layers (e.g., conv and elementwise-add, fold batch_normalization, etc.),
               including folding of fused layers params.
            2. Apply model modification commands from the model script (e.g., resize input, transpose,
               color conversion, etc.)
            3. Run structural optimization algorithms (e.g., dead channels removal, tiling squeeze & excite, etc.)

        On success the runner moves to the ``FP_OPTIMIZED_MODEL`` state. A deprecation warning is
        logged when no ``calib_data`` is given and the backend holds neither calibration data nor a
        random-calibration maximum.

        Args:
            calib_data (optional): Calibration data for optimization algorithms that require inference on actual input data.
                The type depends on the ``data_type`` parameter.
            data_type (optional, :class:`~hailo_sdk_client.exposed_definitions.CalibrationDataType`):
                calib_data's data type, based on enum values:

                * :attr:`~hailo_sdk_client.exposed_definitions.CalibrationDataType.auto` --
                  Automatically detected.
                * :attr:`~hailo_sdk_client.exposed_definitions.CalibrationDataType.np_array` --
                  ``numpy.ndarray``, or dictionary with input layer names as keys, and values types
                  of ``numpy.ndarray``.
                * :attr:`~hailo_sdk_client.exposed_definitions.CalibrationDataType.dataset` --
                  ``tensorflow.data.Dataset`` object with valid signature.
                  signature should be either ((h, w, c), image_info) or
                  ({'input_layer1': (h1, w1, c1), 'input_layer2': (h2, w2, c2)}, image_info)
                  image_info can be an empty dict for the quantization
                * :attr:`~hailo_sdk_client.exposed_definitions.CalibrationDataType.npy_file` --
                  path to a npy or npz file
                * :attr:`~hailo_sdk_client.exposed_definitions.CalibrationDataType.npy_dir` --
                  path to a npy or npz dir. Assumes the same shape for all the items

        """
        if calib_data is None:
            backend = self._sdk_backend
            # Warn only when the backend has no calibration source of its own either.
            if not backend.calibration_data and not backend.calibration_data_random_max:
                self._logger.deprecation_warning(
                    "Optimizing in full precision will require calibration data "
                    "in the near future, to allow more accurate optimization "
                    "algorithms which require inference on actual data.",
                    DeprecationVersion.JUL2024,
                )

        self._optimize_full_precision(DatasetContianer(calib_data, data_type))

    def _optimize_full_precision(self, data_continer: DatasetContianer):
        """Run the backend's full-precision optimization on ``data_continer`` and move to ``FP_OPTIMIZED_MODEL``."""
        self._sdk_backend.optimize_full_precision(data_continer)
        # The state is advanced only after the backend call returned without raising.
        self._state = States.FP_OPTIMIZED_MODEL

    @allowed_states(States.QUANTIZED_MODEL, States.COMPILED_MODEL)
    def analyze_noise(
        self,
        dataset,
        data_type=InferenceDataType.auto,
        data_count: Optional[int] = None,
        batch_size: int = 1,
        analyze_mode: Optional[str] = None,
        **kwargs,
    ) -> dict:
        """
        Run layer noise analysis on a quantized model:
            * Analyze the model accuracy
            * Generate analysis data to be visualized in the Hailo Model profiler

        Args:
            dataset: data for analysis. The type depends on the ``data_type`` parameter.
            data_type (optional, :class:`~hailo_sdk_client.exposed_definitions.InferenceDataType`):
                dataset's data type, based on enum values:

                * :attr:`~hailo_sdk_client.exposed_definitions.InferenceDataType.auto` -- Automatically detection.
                * :attr:`~hailo_sdk_client.exposed_definitions.InferenceDataType.np_array` -- ``numpy.ndarray``,
                  or dictionary with input layer names as keys, and values types of ``numpy.ndarray``.
                * :attr:`~hailo_sdk_client.exposed_definitions.InferenceDataType.dataset` -- ``tensorflow.data.Dataset``
                  object with a valid signature. signature should be either ((h, w, c), image_info) or
                  ({'input_layer1': (h1, w1, c1), 'input_layer2': (h2, w2, c2)}, image_info) image_info can
                  be an empty dict for inference
                * :attr:`~hailo_sdk_client.exposed_definitions.InferenceDataType.npy_file` -- path to a npy or npz file.
                * :attr:`~hailo_sdk_client.exposed_definitions.InferenceDataType.npy_dir` -- path to a npy or npz dir,
                  assumes the same shape to all the items.
            data_count (optional, int): optional argument to limit the number of elements for analysis
            batch_size (optional, int): batch size for analysis
            analyze_mode (optional, str): selects the analyzing mode that will run simple or advanced.
            **kwargs: Additional keyword arguments forwarded to the analysis run.

        Returns:
            dict: The value produced by the internal noise-analysis run.

        """
        # Hand the analysis result back to the caller, matching the declared return type.
        return self._analyze_noise(
            dataset=dataset,
            data_type=data_type,
            data_count=data_count,
            batch_size=batch_size,
            analyze_mode=analyze_mode,
            **kwargs,
        )

    @allowed_states(States.QUANTIZED_MODEL, States.COMPILED_MODEL)
    def _analyze_noise(
        self,
        dataset,
        data_type=InferenceDataType.auto,
        data_count: Optional[int] = None,
        batch_size: int = 1,
        analyze_mode: Optional[str] = None,
        **kwargs,
    ) -> None:
        """
        Convert the given data to a dataset and run the backend's layer analysis tool on it.

        When ``dataset`` is a dictionary, every key is first replaced by the ``name`` of the layer
        that ``self._hn.get_layer_by_name`` returns for it.

        NOTE(review): the signature is annotated ``-> None`` although the backend's result is
        returned -- confirm the intended return type.
        """
        if isinstance(dataset, dict):
            renamed_inputs = {}
            for layer_key, layer_data in dataset.items():
                resolved_name = self._hn.get_layer_by_name(layer_key).name
                renamed_inputs[resolved_name] = layer_data
            dataset = renamed_inputs

        # Only the first element of the returned pair is needed here.
        analysis_data, _ = data_to_dataset(dataset, data_type, self._logger)
        backend = self._sdk_backend
        return backend.run_layer_analysis_tool(analysis_data, data_count, batch_size, analyze_mode, **kwargs)

    @allowed_states(States.HAILO_MODEL, States.FP_OPTIMIZED_MODEL, States.QUANTIZED_BASE_MODEL)
    def optimize(
        self,
        calib_data,
        data_type=CalibrationDataType.auto,
        *,
        work_dir=None,
        checkpoint: SupportedStops = SupportedStops.NONE,
        memento: Optional[FlowCheckPoint] = None,
    ):
        """
        Optimize the model:

            * Modify the network layers.
            * Quantize the model's params, using optional pre-process and post-process algorithms.

        Args:
            calib_data: Calibration data for the Equalization and quantization process. Its type
                depends on the ``data_type`` parameter.
            data_type (:class:`~hailo_sdk_client.exposed_definitions.CalibrationDataType`):
                The type of ``calib_data``, one of the enum values:

                * :attr:`~hailo_sdk_client.exposed_definitions.CalibrationDataType.auto` --
                  Automatically detected.
                * :attr:`~hailo_sdk_client.exposed_definitions.CalibrationDataType.np_array` --
                  ``numpy.ndarray``, or dictionary with input layer names as keys, and values types
                  of ``numpy.ndarray``.
                * :attr:`~hailo_sdk_client.exposed_definitions.CalibrationDataType.dataset` --
                  ``tensorflow.data.Dataset`` object with valid signature.
                  signature should be either ((h, w, c), image_info) or
                  ({'input_layer1': (h1, w1, c1), 'input_layer2': (h2, w2, c2)}, image_info)
                  image_info can be an empty dict for the quantization
                * :attr:`~hailo_sdk_client.exposed_definitions.CalibrationDataType.npy_file` --
                  path to a npy or npz file
                * :attr:`~hailo_sdk_client.exposed_definitions.CalibrationDataType.npy_dir` --
                  path to a npy or npz dir. Assumes the same shape for all the items

            work_dir (optional, str): If not None, dump quantization debug outputs to this directory.
            checkpoint (optional, :class:`~hailo_sdk_client.exposed_definitions.SupportedStops`):
                The optimization process will stop at the given checkpoint.
            memento (optional, :class:`~hailo_sdk_client.exposed_definitions.FlowCheckPoint`):
                The flow memento of the optimization process, used to resume the optimization process.

        Returns:
            FlowCheckPoint: The flow memento of the optimization process.

        """
        flow_result = self._optimize(
            calib_data,
            data_type=data_type,
            work_dir=work_dir,
            checkpoint=checkpoint,
            memento=memento,
        )
        # Only the memento part of the internal result is exposed to the caller.
        return flow_result.flow_memento

    @allowed_states(States.COMPILED_MODEL)
    def get_hailo_runtime_model(self):
        """
        Generate model allowing to run the full ONNX graph using ONNX runtime, including the parts that are offloaded
        to the Hailo-8 (between the start and end nodes) and the parts that are not.

        Returns:
            The object returned by ``HailoONNXModelComposer.compose()``.

        Raises:
            HailoPlatformMissingException: If the ``hailo_platform`` package is not available.
            UnsupportedHailoRuntimeException: If the model was not parsed from ONNX, the runtime
                metadata extraction was disabled, the model is marked as not supported, it has more
                than one entry in its preprocess IO map, it was modified by model script commands,
                or it contains NMS layers.

        """
        if not self._hailo_platform:
            raise HailoPlatformMissingException(
                "The HailoRT Python API (hailo_platform package) must be installed to run Hailo ONNX runtime model",
            )

        # The metadata is not modified by this method, so a single lookup serves all the checks.
        model_meta = self.original_model_meta

        parsed_from_onnx = model_meta is not None and model_meta.get("framework") == str(NNFramework.ONNX)
        if not parsed_from_onnx:
            raise UnsupportedHailoRuntimeException("`get_hailo_runtime_model` supports only model parsed from ONNX")

        if model_meta.get("extractor_disabled", False):
            raise UnsupportedHailoRuntimeException(
                "ONNX metadata extraction was disabled, runtime model can't be "
                "generated. To get a runtime model please call "
                "`translate_onnx_model` with `disable_rt_metadata_extraction=False`",
            )

        # Recorded errors are reported as warnings before deciding whether the model is supported.
        recorded_errors = model_meta.get("errors")
        if recorded_errors:
            self._logger.warning("Errors occurred when preparing Hailo ONNX runtime model")
            for recorded_error in recorded_errors:
                self._logger.warning(recorded_error)

        # Only an explicit ``False`` marks the model as unsupported; a missing key counts as supported.
        if model_meta.get("is_supported_model", True) is False:
            raise UnsupportedHailoRuntimeException("The model is not supported, please check the previous warnings")

        if len(model_meta["preprocess_io_map"]) > 1:
            raise UnsupportedHailoRuntimeException("`get_hailo_runtime_model` supports only model with one input layer")

        has_wrapping_model = self._preprocess_model or self._postprocess_model
        if has_wrapping_model and self._sdk_backend.is_modified_model():
            raise UnsupportedHailoRuntimeException(
                "The given model was modified (via models script commands), and is "
                "no longer equivalent to the original ONNX model. In this case, "
                "generating Hailo ONNX runtime model is not supported.",
            )

        for layer in self._hn:
            if layer.op == LayerType.nms:
                raise UnsupportedHailoRuntimeException("NMS layers in Hailo ONNX runtime model are not supported yet")

        return HailoONNXModelComposer(
            self._hailo_platform,
            self._hef,
            self._preprocess_model,
            self._postprocess_model,
            model_meta,
            self._hn,
        ).compose()

    @staticmethod
    def _is_numeric(target):
        for numeric_attribute in ["IS_NUMERIC", "is_numeric"]:
            if hasattr(target, numeric_attribute):
                return getattr(target, numeric_attribute)
        return True

    @allowed_states(States.HAILO_MODEL, States.FP_OPTIMIZED_MODEL, States.QUANTIZED_MODEL, States.COMPILED_MODEL)
    def save_parsing_report(self, report_path):
        """
        Save the parsing report to a given path.

        If the original model metadata is missing or holds no parsing report, nothing is written
        and an informative message is logged instead.

        Args:
            report_path (string): Path to save the file.

        """
        # The metadata may be None (other methods of this class guard against it as well).
        model_meta = self.original_model_meta
        report = model_meta.get("parsing_report") if model_meta is not None else None
        if report is None:
            self._logger.info("Parsing report doesn't exist")
            return

        # Serialize before opening the file, so a serialization failure can't truncate an existing file.
        serialized_report = json.dumps(report.dict(), indent=4)
        with open(report_path, "w") as f:
            f.write(serialized_report)
        # Report success only after the file was actually written.
        self._logger.info(f"Saved parsing report to: {report_path}")

    @allowed_states(States.HAILO_MODEL, States.FP_OPTIMIZED_MODEL, States.QUANTIZED_MODEL, States.COMPILED_MODEL)
    def get_detected_nms_config(self, meta_arch, config_path=None):
        """
        Get the detected NMS config file: anchors detected automatically from the model's post-process,
        and default values corresponding to the meta-architecture specified.

        If the model metadata holds no detected anchors, no file is written and a warning is logged.

        Args:
            meta_arch (:class:`~hailo_sdk_common.hailo_nn.hn_definitions.NMSMetaArchitectures`):
                Meta architecture of the NMS post process.
            config_path (string, optional): Path to save the generated config file.
                Defaults to '{meta_arch}_nms_config.json'.

        Raises:
            InvalidArgumentsException: If ``meta_arch`` is not one of the architectures supported by this API.

        """
        supported_archs = (
            NMSMetaArchitectures.YOLOV5,
            NMSMetaArchitectures.YOLOV5_SEG,
            NMSMetaArchitectures.YOLOX,
            NMSMetaArchitectures.YOLOV6,
        )
        if meta_arch not in supported_archs:
            # Fall back to the raw argument so a non-enum value still yields the intended exception
            # rather than an AttributeError while formatting the message.
            arch_name = getattr(meta_arch, "value", meta_arch)
            raise InvalidArgumentsException(f"{arch_name} is currently not supported in this API.")

        # The metadata may be None (other methods of this class guard against it as well).
        model_meta = self.original_model_meta
        if not model_meta or "detected_anchors" not in model_meta:
            self._logger.warning(
                "No auto-detected anchors were found in the model's metadata, NMS config file was not generated."
            )
            return

        full_config = self._get_nms_full_config(meta_arch)
        config_path = config_path or f"{meta_arch.value}_nms_config.json"
        with open(config_path, "w") as f:
            f.write(json.dumps(full_config, indent=4))
        self._logger.info(
            "Saved NMS configuration template with the auto-detected anchors and typical default "
            f"parameters at {config_path}. Example for an alls command using it: "
            f"nms_postprocess(config_path={config_path}, meta_arch={meta_arch.value}).",
        )

    def _get_nms_full_config(self, meta_arch):
        """
        Build the NMS config dictionary for ``meta_arch`` from the packaged defaults and the detected anchors.

        The default config json of the meta architecture is loaded, then ``image_dims`` and ``classes``
        are set from the model, ``bbox_decoders`` is added for architectures that are not in
        ``ANCHORLESS_YOLOS``, and finally any detected ``config_values`` override the result.

        The detected anchors stored in the original model metadata are read but never modified.

        Args:
            meta_arch: Meta architecture of the NMS post process; its ``value`` selects the defaults file.

        Returns:
            dict: The full NMS configuration.

        """
        default_path = SDKPaths().join_sdk_client(f"tools/core_postprocess/default_nms_config_{meta_arch.value}.json")
        with open(default_path) as f:
            full_config = json.load(f)

        detected_anchors = self.original_model_meta["detected_anchors"]
        hailo_nn = self.get_hn_model()
        # NOTE(review): presumably entries 1:3 of the input shape are (height, width) -- confirm.
        img_dims = hailo_nn.get_input_layers()[0].input_shape[1:3]
        detected_nms_layers = [hailo_nn.get_layer_by_original_name(name) for name in detected_anchors["end_nodes"]]

        full_config["image_dims"] = img_dims
        full_config["classes"] = self._detect_num_of_classes(meta_arch, detected_nms_layers)
        if meta_arch not in ANCHORLESS_YOLOS:
            # Build named copies, so the anchors kept in the model metadata stay untouched.
            full_config["bbox_decoders"] = [
                {**anchor, "name": f'bbox_decoder_{anchor["stride"]}'} for anchor in detected_anchors["info"].values()
            ]

        config_values = detected_anchors.get("config_values")
        if config_values:
            full_config.update(config_values)
        return full_config

    def _detect_num_of_classes(self, meta_arch, output_layers):
        """
        Derive the number of classes from the shapes of the detected NMS output layers.

        Args:
            meta_arch: Meta architecture of the NMS post process, selects the derivation rule.
            output_layers: The detected output layers; their ``output_shape``, ``output_shapes`` and
                ``activation`` attributes are inspected.

        Returns:
            The number of classes computed for the given meta architecture.

        Raises:
            InvalidArgumentsException: If ``meta_arch`` is not handled by any of the rules below.

        """
        if meta_arch == NMSMetaArchitectures.YOLOV5_SEG:
            # One output carries the silu activation (segmentation), the other one does not.
            seg_layer = next(layer for layer in output_layers if layer.activation == ActivationType.silu)
            conv_layer = next(layer for layer in output_layers if layer.activation != ActivationType.silu)
            features_per_anchor = conv_layer.output_shape[3] / DEFAULT_YOLO_ANCHORS
            return int(features_per_anchor - seg_layer.output_shape[3] - DEFAULT_BOX_AND_OBJ_PXLS)

        if meta_arch == NMSMetaArchitectures.YOLOV5:
            features_per_anchor = output_layers[0].output_shape[3] / DEFAULT_YOLO_ANCHORS
            return int(features_per_anchor - DEFAULT_BOX_AND_OBJ_PXLS)

        if meta_arch in (NMSMetaArchitectures.YOLOX, NMSMetaArchitectures.YOLOV6):
            layers_count = len(output_layers)
            if layers_count == 1 and meta_arch == NMSMetaArchitectures.YOLOX:
                return output_layers[0].output_shape[2] - DEFAULT_BOX_AND_OBJ_PXLS
            if layers_count == 2 and meta_arch == NMSMetaArchitectures.YOLOV6:
                return max(layer.output_shape[2] for layer in output_layers)

            # Prefer the sigmoid-activated outputs when there are any, otherwise consider all outputs.
            sigmoid_layers = [layer for layer in output_layers if layer.activation == ActivationType.sigmoid]
            candidate_layers = sigmoid_layers or output_layers
            # NOTE(review): the filter reads ``output_shapes[0]`` while the result uses ``output_shape``.
            classes_shape = next(
                layer.output_shape for layer in candidate_layers if layer.output_shapes[0][-1] != 1
            )
            return classes_shape[-1]

        if meta_arch == NMSMetaArchitectures.YOLOV8:
            reg_length = self.original_model_meta["detected_anchors"]["config_values"]["regression_length"]
            reg_layer_f_out = reg_length * (DEFAULT_BOX_AND_OBJ_PXLS - 1)
            # The first output whose last dimension differs from the regression width is used.
            return next(
                layer.output_shape[-1] for layer in output_layers if layer.output_shape[-1] != reg_layer_f_out
            )

        raise InvalidArgumentsException(f"Got unexpected meta_arch: {meta_arch.value}")

    @allowed_states(States.QUANTIZED_MODEL)
    def init_lora_model(self, lora_weights_mapping):
        """
        Set up the basic state of the LoRA model.

        When the backend already holds LoRA weights metadata, nothing is initialized and an
        informative message is logged instead.

        Args:
            lora_weights_mapping (str): A path for a json dictionary that maps between the LORA layer names and the
                corresponding weight variables to be added later from each new adapter.

        """
        backend = self._sdk_backend
        if backend.lora_weights_metadata:
            self._logger.info("Model was already initialized with LoRA weights mapping.")
            return
        backend.init_lora_model(lora_weights_mapping)

    @allowed_states(States.QUANTIZED_MODEL)
    def load_lora_weights(self, lora_weights_path, lora_adapter_name):
        """
        Add a LORA weights set (single adapter only) to a quantized Hailo model.

        After the weights are loaded the runner state becomes ``States.QUANTIZED_BASE_MODEL``.

        Args:
            lora_weights_path (str): Path to the LORA weights file, in .safetensors format.
            lora_adapter_name (str): The name of the adapter representing the LoRA weights.

        """
        backend = self._sdk_backend
        backend.load_lora_weights(lora_weights_path, lora_adapter_name)
        self._state = States.QUANTIZED_BASE_MODEL

        # Report the last adapter in the backend's list together with the total count.
        loaded_adapters = backend.get_lora_adapters()
        latest_adapter = loaded_adapters[-1]
        adapters_count = len(loaded_adapters)
        self._logger.info(
            f"Added LoRA weights for adapter {latest_adapter}, currently the model has {adapters_count} adapters."
        )

    @property
    def use_service(self):
        """Internal ``use_service`` flag; every read logs a warning that it is not meant for users."""
        self._logger.warning("use_service is an internal parameter and should not be used by the user.")
        return self._use_service

    @use_service.setter
    def use_service(self, use_service: bool):
        # Setting is still honored; the warning only flags that the parameter is internal.
        self._logger.warning("use_service is an internal parameter and should not be set by the user.")
        self._use_service = use_service

    @property
    def original_model_meta(self):
        """Metadata stored for the original model; may be ``None`` (callers in this class check for it)."""
        return self._original_model_meta

    @original_model_meta.setter
    def original_model_meta(self, meta):
        # Stored as-is, no validation or copy is performed.
        self._original_model_meta = meta

    @property
    def original_model_path(self):
        """Path stored for the original model, returned exactly as it was set."""
        return self._original_model_path

    @original_model_path.setter
    def original_model_path(self, path):
        # Stored as-is, no validation or normalization is performed.
        self._original_model_path = path
