#!/usr/bin/env python
"""
This module represents the SDK backend, its utilities and its IPC wrapping.
The module holds the ingredients needed to create a normal runner and our own TF graph.
All the assembling logic is done here, and the rest is done by the client.
"""

import copy
import csv
import json
import os
import time
from contextlib import contextmanager
from typing import Tuple

import jsonschema
import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerTranslationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import CheckerConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    MAX_NUM_REPEATS_ELTWISE,
    RECOMMENDED_COMP_LEVEL,
    RECOMMENDED_OPTM_LEVEL,
    ColorConversionType,
    DistributionStrategy,
    FeaturePolicy,
    FinetunePolicy,
    FormatConversionType,
    GPUAvailabilityMode,
    NpzExportMode,
    OptimizationTarget,
)
from hailo_model_optimization.acceleras.utils.dataset_util import DatasetContianer
from hailo_model_optimization.acceleras.utils.distributed_utils import (
    get_strategy,
    gpu_distributed_context,
)
from hailo_model_optimization.acceleras.utils.tf_utils import (
    get_gpu_availability_mode,
)
from hailo_model_optimization.algorithms.hailo_layer_noise_analysis import HailoQuantAnalyzer
from hailo_model_optimization.algorithms.lat_utils.lat_utils import AnalysisMode
from hailo_model_optimization.algorithms.mixed_precision.create_mixed_precision import CreateMixedPrecision
from hailo_model_optimization.flows.inference_flow import (
    HWInferenceModel,
    SimulationInferenceModel,
    SimulationTrainingModel,
)
from hailo_model_optimization.flows.optimization_flow import OptimizationFlow, SupportedStops

# TODO SDK-48985 FlowCheckPoint should not be here
from hailo_model_optimization.tools.mo_script_parser import MOScriptParser, OptimizationFlavorsInfo
from hailo_model_optimization.tools.orchestator import FlowCheckPoint
from hailo_sdk_client.allocator.estimator import Estimator
from hailo_sdk_client.allocator.hailo_tools_runner import HailoToolsRunner
from hailo_sdk_client.allocator.hef_wrapper import HefWrapper
from hailo_sdk_client.emulator import model_factory
from hailo_sdk_client.emulator.model import OptionalModelParams
from hailo_sdk_client.exposed_definitions import InferenceContext
from hailo_sdk_client.numeric_translator import quantize_model, set_quantized_params
from hailo_sdk_client.paths_manager.build_dir_creator import BuildDirCreator
from hailo_sdk_client.paths_manager.sdk_runner_paths import ConfigPaths
from hailo_sdk_client.post_fuser.post_fuser import HailoNNPostFuser
from hailo_sdk_client.quantization.quantize import ModelOptimizer
from hailo_sdk_client.runner.exceptions import UnsupporteLoraAdapterException
from hailo_sdk_client.sdk_backend.modification_config import ModificationsConfig
from hailo_sdk_client.sdk_backend.profiler.profiler import Profiler
from hailo_sdk_client.sdk_backend.script_parser.commands import FormatConversionCommand, SupportedCommands
from hailo_sdk_client.sdk_backend.script_parser.input_conversion_commands import InputConversionCommand
from hailo_sdk_client.sdk_backend.script_parser.model_modifications_commands import (
    ResizeCommand,
    SetSeedCommand,
    TransposeCommand,
)
from hailo_sdk_client.sdk_backend.script_parser.model_script_parser import ModelScriptParser
from hailo_sdk_client.sdk_backend.script_parser.nms_postprocess_command import NMSPostprocessCommand
from hailo_sdk_client.sdk_backend.sdk_backend_data_class import CheckpointInfo, InferInfo, InternalContextInfo
from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import (
    BackendRuntimeException,
    BackendValueError,
    ModelRunnerException,
)
from hailo_sdk_client.tools.calib_set_generator import get_random_calib_dataset
from hailo_sdk_client.tools.hailo_nn_tools import get_subgraphs
from hailo_sdk_client.tools.layers.lora_utils import load_lora_weights
from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN, HNImporter
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers import FeatureMultiplierLayer
from hailo_sdk_common.hailo_nn.model_optimization.configuration_verifier import (
    apply_quantization_config_to_hn,
    verify_commands,
)
from hailo_sdk_common.logger.logger import DeprecationVersion, default_logger
from hailo_sdk_common.model_params.model_params import ModelParams
from hailo_sdk_common.paths_manager.paths import SDKPaths
from hailo_sdk_common.profiler.profiler_common import ProfilerModes
from hailo_sdk_common.targets.inference_targets import (
    EmulationInferenceTargets,
    FineTuneParams,
    ParamsKinds,
    SdkFPOptimized,
    SdkMixedParams,
    SdkNative,
)


class BackendChecksMessages:
    """
    Advisory log messages emitted on behalf of an SDK backend.

    Each method inspects the state of the owning backend and, when its condition holds,
    writes a recommendation or a warning to the backend's logger. The methods return
    nothing and do not modify the backend.
    """

    def __init__(self, backend: "SDKBackendQuantization"):
        # The backend whose architecture, model and logger are consulted by the checks.
        self.backend = backend

    def nv12_optimization_message(self):
        """
        Recommend feeding the network with NV12 input.

        When the model is running on mercury arch and its input has 3 channels
        (YUV/RGB/BGR input format), it is recommended to feed the network with NV12
        format for better performance. The hint is skipped when the single real input
        layer is already an NV12 format conversion.
        """
        hw_arch = self.backend._hw_arch
        if hw_arch.name not in hw_arch.HAILO15_ARCHS:
            return

        real_input_layers = self.backend._model.get_real_input_layers()
        if len(real_input_layers) != 1:
            return

        input_layer = real_input_layers[0]
        if len(input_layer.input_shapes) != 1 or input_layer.input_shape[-1] != 3:
            return

        # conversion_type is only read on format conversion layers
        already_nv12 = (
            input_layer.op == LayerType.format_conversion
            and input_layer.conversion_type == FormatConversionType.nv12_to_hailo_yuv
        )
        if already_nv12:
            return

        self.backend._logger.info(
            f"To optimize the performance in {hw_arch.name} "
            "device it is recommended to use NV12 input to the NN-core. "
            "For example, compile the model with the following command: "
            "reshape = input_conversion(input_layer1, nv12_to_hailo_yuv)"
        )

    def log_optimization_flavor_comments(self, flavor: "OptimizationFlavorsInfo"):
        """
        Comment on the chosen optimization/compression levels of a large model.

        Nothing is logged unless the backend architecture is one of ``MERCURY_ARCHS``,
        the parameters count exceeds ``HAILO15_LARGE_MODEL_PARAMS_TH`` and at least one
        of the two levels differs from its recommended value.

        Args:
            flavor: object exposing ``optimization_level`` and ``compression_level``.
        """
        num_of_parameters = self.backend._get_parameters_count()
        hw_arch = self.backend.hw_arch
        is_large_model = (
            hw_arch in hw_arch.MERCURY_ARCHS and num_of_parameters > hw_arch.HAILO15_LARGE_MODEL_PARAMS_TH
        )
        if not is_large_model:
            return

        optimization_is_recommended = flavor.optimization_level == RECOMMENDED_OPTM_LEVEL
        compression_is_recommended = flavor.compression_level == RECOMMENDED_COMP_LEVEL
        if optimization_is_recommended and compression_is_recommended:
            return

        # Each level is compared against its own recommended value
        if flavor.compression_level < 1 and optimization_is_recommended:
            self.backend._logger.info(
                f"For best throughput it is recommended setting the compression level to {RECOMMENDED_COMP_LEVEL}."
            )
        elif flavor.optimization_level < 1 and compression_is_recommended:
            self.backend._logger.warning(
                f"For best accuracy results it is recommended setting the optimization level to {RECOMMENDED_OPTM_LEVEL}."
            )
        else:
            # the compression level and the optimization level are not set to the recommended values
            self.backend._logger.info(
                "To obtain best performance for models with number "
                f"of parameters larger than {hw_arch.HAILO15_LARGE_MODEL_PARAMS_TH}, "
                f"it is recommended to use optimization level and compression level {RECOMMENDED_OPTM_LEVEL}."
            )

    def default_yuv_to_rgb_warning(self, command):
        """
        Warn about the default YUV to RGB/BGR conversion standard.

        Two warnings are logged when the backend architecture is one of ``HAILO15_ARCHS``
        and ``command`` is an input conversion of type ``yuv_to_rgb`` or ``yuv_to_bgr``.

        Args:
            command: model script command exposing ``function_name`` and, for input
                conversion commands, ``conversion_type``.
        """
        hw_arch = self.backend._hw_arch
        uses_default_yuv_conversion = (
            hw_arch in hw_arch.HAILO15_ARCHS
            and command.function_name == SupportedCommands.INPUT_CONVERSION
            and command.conversion_type in [ColorConversionType.yuv_to_rgb, ColorConversionType.yuv_to_bgr]
        )
        if not uses_default_yuv_conversion:
            return

        self.backend._logger.warning(
            "Be advised the Dataflow Compiler default YUV to RGB/BGR conversion uses YUV601 standard, which is different than ISP and may cause accuracy degradation. "
            "For better results consider using yuv_full_range_to_rgb/bgr instead."
        )
        self.backend._logger.warning(
            "The Dataflow Compiler default YUV to RGB/BGR will be changed to yuv_full_range_to_rgb/bgr in the next version.",
        )


class SDKBackendCore:
    """
    @purpose: A wrapper for ModelRunner.
              This runner is supposed to do all the parsing of data and params, and then
              build the graph of the model.
    """

    def __init__(self, hn, hw_arch, logger=None, alls_ignore_invalid_cmds=False):
        """
        Build the backend state from an HN description.

        Args:
            hn: HN description of the model; it is serialized with ``json.dumps`` before parsing.
            hw_arch: hardware architecture object, stored as-is.
            logger: logger to use; ``default_logger()`` when omitted.
            alls_ignore_invalid_cmds: forwarded to the model script parser.
        """
        # Maintain the state of the model
        model = HailoNN.from_hn(json.dumps(hn))
        self._model = model
        # The native model is rebuilt from the exported HN, so it is a separate object from ``model``
        self._native_model = model.from_parsed_hn(model.to_hn(model.name, json_dump=False))
        self._fp_model = None
        self._model_name = model.name
        self._hw_arch = hw_arch
        self._params = None
        self._params_fp_optimized = None
        self._params_hailo_optimized = None
        self._params_translated = None
        self._params_statistics = None
        self._logger = logger or default_logger()
        self._script_parser = ModelScriptParser(self._model, alls_ignore_invalid_cmds=alls_ignore_invalid_cmds)
        self._script_parser.sorting_disabled = True
        self._mo_flavor = None
        self._flavor_config = None
        self._har = None
        self._nms_metadata = None
        self._modifications_meta_data = ModificationsConfig()
        self._optimization_flow_memento: FlowCheckPoint = None
        self._lora_weights_metadata = None
        self._messages = BackendChecksMessages(self)

    def update_model(self, model: "HailoNN"):
        """
        Replace the working model and hand the new model to the script parser.
        """
        self._model = model
        self._script_parser.update_model(model)

    def update_from_hn(self, hn: dict):
        """
        Validate ``hn`` against the HN schema and replace the working model with it.
        """
        jsonschema.validate(hn, HNImporter()._load_schema())
        # Updating HN model
        model = HailoNN.from_hn(json.dumps(hn))
        self.update_model(model)

    def update_native_model(self, model):
        """
        Replace the native model.
        """
        self._native_model = model

    def update_fp_model(self, model):
        """
        Replace the FP-optimized model.
        """
        self._fp_model = model

    @staticmethod
    def _copy_params(params):
        """
        Return a copy of ``params``: ``None`` stays ``None``, a ``ModelParams`` is wrapped
        in a new ``ModelParams`` and anything else is converted with ``dict``.
        """
        if params is None:
            return None
        if isinstance(params, ModelParams):
            return ModelParams(params)
        return dict(params)

    @staticmethod
    def _export_hn_dict(model):
        """
        Export ``model`` as an HN dict, going through a rebuilt copy of the model.
        """
        model_copy = model.from_parsed_hn(model.to_hn(model.name, json_dump=False))
        return model_copy.to_hn(model.name, json_dump=False)

    def get_params(self):
        """
        Get the native (non quantized) params the runner uses
        """
        return self._copy_params(self._params)

    def get_hn_dict(self):
        """
        Get the HN dict of the working model
        """
        return self._export_hn_dict(self.model)

    def get_hn_native_dict(self):
        """
        Get the HN dict of the native model
        """
        return self._export_hn_dict(self.native_model)

    def get_hn_fp_dict(self):
        """
        Get the HN dict of the FP-optimized model
        """
        return self._export_hn_dict(self.fp_model)

    def get_params_fp_optimized(self):
        """
        Get the FP-optimized params the runner uses
        """
        return self._copy_params(self._params_fp_optimized)

    def get_params_hailo_optimized(self):
        """
        Get the Hailo-optimized params the runner uses
        """
        return self._copy_params(self._params_hailo_optimized)

    def get_params_translated(self):
        """
        Get the quantized params the runner uses
        """
        return self._copy_params(self._params_translated)

    def get_params_statistics(self):
        """
        Get Statistics parameters about the model
        """
        return self._copy_params(self._params_statistics)

    def load_params(self, params, params_kind=None):
        """
        Store ``params`` in the slot that matches their kind.

        Args:
            params: params to load; ``None`` stores ``None`` in the selected slot.
            params_kind: kind to load the params as. When omitted, the kind recorded in the
                params is used, falling back to ``ParamsKinds.NATIVE``.

        Returns:
            The params kind under which the params were stored.

        Raises:
            BackendValueError: the requested kind contradicts the kind recorded in the
                params, or the resulting kind is not supported.
        """
        if params_kind == ParamsKinds.NATIVE_FUSED_BN:
            params_kind = ParamsKinds.NATIVE

        if params is None:
            loaded_params_kind = params_kind if params_kind is not None else ParamsKinds.NATIVE
            params_kind = loaded_params_kind  # prevent legacy flow, and make sure verification succeeds
            params_to_load = None
        else:
            params_to_load = ModelParams(params)
            loaded_params_kind = params_to_load.params_kind_enum

        if loaded_params_kind is not None:
            if loaded_params_kind == ParamsKinds.NATIVE_FUSED_BN:
                loaded_params_kind = ParamsKinds.NATIVE
            self._verify_params_kind(params_kind, loaded_params_kind)
            params_kind = loaded_params_kind
        else:  # Legacy npz
            if params_kind is None:
                params_kind = ParamsKinds.NATIVE
            params_to_load.set_params_kind(params_kind)

        if params_to_load:
            if not self.fp_model and params_kind in [ParamsKinds.FP_OPTIMIZED, ParamsKinds.TRANSLATED]:
                self.update_fp_model(self._model)

            self._validate_params(params_to_load, params_kind)

        if params_kind == ParamsKinds.NATIVE:
            self._params = params_to_load
        elif params_kind == ParamsKinds.TRANSLATED:
            self._params_translated = params_to_load
        elif params_kind == ParamsKinds.FP_OPTIMIZED:
            self._params_fp_optimized = params_to_load
        elif params_kind == ParamsKinds.HAILO_OPTIMIZED:
            self._params_hailo_optimized = params_to_load
        elif params_kind == ParamsKinds.STATISTICS:
            self._params_statistics = params_to_load
        else:
            raise BackendValueError(f"Bad params kind: {params_kind}")

        return params_kind

    def _are_params_required(self, params_kind):
        """
        Tell whether the model that matches ``params_kind`` requires weights.

        Raises:
            BackendValueError: ``params_kind`` is not a supported kind.
        """
        if params_kind in [ParamsKinds.NATIVE, ParamsKinds.NATIVE_FUSED_BN]:
            return self._native_model.requires_native_weights
        elif params_kind == ParamsKinds.FP_OPTIMIZED:
            return self._fp_model.requires_native_weights
        elif params_kind == ParamsKinds.HAILO_OPTIMIZED:
            return self._model.requires_native_weights
        elif params_kind == ParamsKinds.TRANSLATED:
            return self._model.requires_quantized_weights
        elif params_kind == ParamsKinds.STATISTICS:
            return False
        else:
            raise BackendValueError(f"Invalid ParamsKinds received: {params_kind}")

    @staticmethod
    def _param_base_name(param_key):
        """
        Return the last ``/``-separated component of ``param_key`` without its last two
        characters (presumably a ``:0`` style suffix - TODO confirm).
        """
        return param_key.split("/")[-1][:-2]

    def _validate_params(self, model_params, params_kind):
        """
        Check the received ModelParams.

        Raises:
            BackendValueError: NaN values were found in one of the checked params, or a
                param with a dedicated validator was rejected by it.
            BackendRuntimeException: params are required for ``params_kind`` but were not received.
        """
        params_to_check = {
            "kernel",
            "gamma",
            "epsilon",
            "beta",
            "moving_mean",
            "moving_variance",
            "bias",
            "leaky_alpha",
            "activation_threshold",
            "activation_delta_bias",
            "swish_beta",
            "activation_less_values",
            "hardsigmoid_alpha",
            "hardsigmoid_beta",
            "clip_min",
            "clip_max",
            "activation_greater_values",
        }
        validators = {"power_table": FeatureMultiplierLayer.validate_table}

        not_good_params = []
        for key in model_params.keys():
            if self._param_base_name(key) not in params_to_check:
                continue
            param = model_params[key]
            values = param if isinstance(param, list) else [param]
            # A key is reported once, even when several of its arrays hold NaN values
            if any(np.any(np.isnan(value)) for value in values):
                not_good_params.append(key)

        if len(not_good_params) > 0:
            raise BackendValueError(f"Unsupported NaN values were found in params {not_good_params}")

        for key in model_params.keys():
            validator = validators.get(self._param_base_name(key))
            if validator is None:
                continue
            try:
                validator(model_params[key])
            except Exception as e:
                key_parts = key.split("/")
                # Fall back to the whole key when it has no layer component
                layer_name = key_parts[1] if len(key_parts) > 1 else key
                raise BackendValueError(f"Bad value of param {key} in layer {layer_name}: {e!s}") from e

        params_received = model_params and len(model_params.params) > 1
        params_required = self._are_params_required(params_kind)

        if params_required and not params_received:
            raise BackendRuntimeException(f"{params_kind} params are required and were not received")

        elif not params_required and params_received and params_kind != ParamsKinds.STATISTICS:
            self._logger.warning(f"{params_kind} params were received even though they are not required")

    @staticmethod
    def _verify_params_kind(params_kind, loaded_params_kind):
        """
        Make sure the requested kind does not contradict the kind recorded in the params.

        Args:
            params_kind: kind requested by the caller; ``None`` disables the check.
            loaded_params_kind: kind recorded in the params being loaded.

        Raises:
            BackendValueError: the two kinds differ.
        """
        if (params_kind is not None) and (loaded_params_kind != params_kind):
            raise BackendValueError(
                f"Given params to load seem to be {loaded_params_kind}, so they can't be loaded as {params_kind}.",
            )

    @property
    def model_name(self):
        """Name of the working model."""
        return self._model.name

    @property
    def model(self):
        """The working model."""
        return self._model

    @property
    def native_model(self):
        """The native model."""
        return self._native_model

    @property
    def fp_model(self):
        """The FP-optimized model, ``None`` until one is set."""
        return self._fp_model

    @property
    def hw_arch(self):
        """The hardware architecture object given at construction."""
        return self._hw_arch

    @property
    def requires_native_weights(self):
        """Whether the working model requires native weights."""
        return self._model.requires_native_weights

    @property
    def requires_quantized_weights(self):
        """Whether the working model requires quantized weights."""
        return self._model.requires_quantized_weights

    @property
    def model_script(self):
        """The loaded model script as a string, ``None`` when no command is loaded."""
        if len(self._script_parser.commands) > 0:
            return str(self._script_parser)
        return None

    @property
    def modifications_meta_data(self):
        """Meta data of the model modifications."""
        return self._modifications_meta_data

    @modifications_meta_data.setter
    def modifications_meta_data(self, meta_data):
        self._modifications_meta_data = meta_data

    @property
    def lora_weights_metadata(self):
        """Metadata of the LoRA weights, ``None`` until it is loaded."""
        return self._lora_weights_metadata

    @lora_weights_metadata.setter
    def lora_weights_metadata(self, lora_weights_metadata):
        self._lora_weights_metadata = lora_weights_metadata

    @classmethod
    def from_hn_json(cls, hn, *args, **kws):
        """
        Get a runner instance from an existing Hailo model.

        Args:
            hn: the HN description as a JSON string (an already parsed dict is passed through).
        """
        # The constructor serializes its input itself, so it must receive the parsed HN
        hn_dict = json.loads(hn) if isinstance(hn, (str, bytes, bytearray)) else hn
        return cls(hn_dict, *args, **kws)

    @classmethod
    def from_hn_path(cls, hn_path, *args, **kws):
        """
        Get a runner instance from an existing Hailo model.

        Args:
            hn_path: path of a file holding the HN description as JSON.
        """
        with open(hn_path, encoding="utf-8") as hn_file:
            hn_dict = json.load(hn_file)
        return cls(hn_dict, *args, **kws)

    @classmethod
    def from_model(cls, model, *args, **kws):
        """
        Get a runner instance from an existing Hailo model.

        Args:
            model: the model; it is exported to an HN dict before being handed to the constructor.
        """
        return cls(model.to_hn(model.name, json_dump=False), *args, **kws)

    def model_script_commands(self):
        """
        Get the commands currently held by the script parser.
        """
        return self._script_parser.commands

    def load_model_script(self, model_script, append=False):
        """
        Parse a model script given as a string; ``None`` is treated as an empty script.
        """
        if model_script is None:
            model_script = ""
        nms_config = self._nms_metadata.config_file if self._nms_metadata else None
        self._script_parser.parse_script(model_script, append, nms_config)
        # Prevent contradiction with allocation scripts once quantized har is already loaded
        self._har = None

    def load_model_script_from_file(self, model_script_path, append=False):
        """
        Parse a model script from a file.
        """
        nms_config = self._nms_metadata.config_file if self._nms_metadata else None
        # NOTE(review): positional order differs from parse_script (nms_config before append) -
        # confirm against the parser signature
        self._script_parser.parse_script_from_file(model_script_path, nms_config, append)
        # Prevent contradiction with allocation scripts once quantized har is already loaded
        self._har = None

    def load_model_script_from_har(self, har):
        """
        Parse the model script stored in ``har``.
        """
        self._script_parser.parse_script_from_har(har)

    def reapply_alls_commands_on_load_har(self, hn, params):
        """
        Apply again, on ``hn`` and ``params``, the loaded commands that have to be re-applied.
        """
        reapply_commands_type = [SetSeedCommand]
        # Exact type match on purpose: subclasses of the listed commands are not re-applied
        reapply_commands = [
            command for command in self._script_parser.commands if type(command) in reapply_commands_type
        ]

        for command in reapply_commands:
            command.apply(hn, params)

    def init_lora_model(self, lora_weights_mapping):
        """
        Load the LoRA weights mapping and mark the single scope of each model as a LoRA adapter.

        Args:
            lora_weights_mapping: path of a JSON file holding the LoRA weights metadata.

        Raises:
            UnsupporteLoraAdapterException: one of the models does not have exactly one scope.
        """
        models = [self._native_model, self._fp_model, self._model]
        if any(len(hn.net_params.net_scopes) != 1 for hn in models):
            raise UnsupporteLoraAdapterException("LoRA is not supported for models with multiple scopes")
        with open(lora_weights_mapping, "r", encoding="utf-8") as f:
            self._lora_weights_metadata = json.load(f)

        for hn in models:
            hn.net_params.lora_adapters = [hn.net_params.net_scopes[0]]

    def load_lora_weights(self, lora_weights_file, lora_adapter_name):
        """
        Load the weights of a LoRA adapter into the native, FP-optimized and working models.

        The three models are updated in that order, each together with its params. The LoRA
        metadata is replaced by the metadata returned for the native model, when there is one.
        """
        native_model, native_params, new_lora_metadata = load_lora_weights(
            model=self._native_model,
            params=self._params,
            lora_layers_metadata=self._lora_weights_metadata,
            lora_weights_file=lora_weights_file,
            lora_adapter_name=lora_adapter_name,
        )
        self.update_native_model(native_model)
        self.load_params(native_params, params_kind=ParamsKinds.NATIVE)
        if new_lora_metadata:
            self.lora_weights_metadata = new_lora_metadata.copy()

        fp_model, fp_params, _ = load_lora_weights(
            model=self._fp_model,
            params=self._params_fp_optimized,
            lora_layers_metadata=self._lora_weights_metadata,
            lora_weights_file=lora_weights_file,
            lora_adapter_name=lora_adapter_name,
        )
        self.update_fp_model(fp_model)
        self.load_params(fp_params, params_kind=ParamsKinds.FP_OPTIMIZED)

        model, ho_params, _ = load_lora_weights(
            model=self._model,
            params=self._params_hailo_optimized,
            lora_layers_metadata=self._lora_weights_metadata,
            lora_weights_file=lora_weights_file,
            lora_adapter_name=lora_adapter_name,
            is_acceleras_params=True,
            log_info=True,
        )
        self.update_model(model)
        self.load_params(ho_params, params_kind=ParamsKinds.HAILO_OPTIMIZED)

    @staticmethod
    def _is_numeric(target):
        """
        Return the ``IS_NUMERIC`` (or ``is_numeric``) attribute of ``target``; ``True`` when it has neither.
        """
        for numeric_attribute in ["IS_NUMERIC", "is_numeric"]:
            if hasattr(target, numeric_attribute):
                return getattr(target, numeric_attribute)
        return True

    @contextmanager
    def override_nms_score_threshold(self, nms_score_threshold):
        """
        Temporarily override the NMS scores threshold.

        Nothing is overridden when there is no NMS metadata or when ``nms_score_threshold``
        is ``None``. The original threshold is restored on exit, also when the body raises.
        """
        if not self._nms_metadata or nms_score_threshold is None:
            yield
            return

        nms_config = self._nms_metadata.nms_config
        orig_nms_scores_th = nms_config.nms_scores_th
        nms_config.nms_scores_th = nms_score_threshold
        try:
            yield
        finally:
            # Restore through the metadata, in case its config object was replaced meanwhile
            self._nms_metadata.nms_config.nms_scores_th = orig_nms_scores_th


class SDKBackendGraph(SDKBackendCore):
    def __init__(self, hn, hw_arch, alls_ignore_invalid_cmds=False):
        """
        Initialize the core backend state and the graph related state on top of it.

        Args:
            hn: HN description of the model, forwarded to the parent constructor.
            hw_arch: hardware architecture object, forwarded to the parent constructor.
            alls_ignore_invalid_cmds: forwarded to the parent constructor.
        """
        super().__init__(hn, hw_arch, alls_ignore_invalid_cmds=alls_ignore_invalid_cmds)
        # Weights are loaded normally unless the weightless mode is switched on later
        self._force_weightless_model = False
        # No executable model exists and no params were used yet
        self._last_used_params = None
        self._executable_model = None
        self._models_history = {}

    def target_to_model(self, target):
        """
        Pick the model that matches an emulation target.

        The native and the FP-optimized targets map to their dedicated models; any other
        target maps to the working model.
        """
        model_by_target_name = {
            EmulationInferenceTargets.SDK_NATIVE: self._native_model,
            EmulationInferenceTargets.SDK_FP_OPTIMIZED: self._fp_model,
        }
        target_name = target.name
        try:
            return model_by_target_name[target_name]
        except KeyError:
            return self._model

    @property
    def force_weightless_model(self):
        """Whether the backend is forced to work without weights."""
        return self._force_weightless_model

    @force_weightless_model.setter
    def force_weightless_model(self, force_weightless_model):
        # Only genuine booleans are accepted
        if isinstance(force_weightless_model, bool):
            self._force_weightless_model = force_weightless_model
            return
        raise TypeError("force_weightless_model has to be bool")

    @property
    def executable_model(self):
        """The stored executable model; ``None`` as long as none was stored."""
        return self._executable_model

    def load_params(self, params, params_kind=None):
        """
        Load params through the parent implementation, unless the weightless mode is on.

        Returns:
            The value returned by the parent implementation, or ``None`` when the loading
            was skipped because of the weightless mode.
        """
        if not self.force_weightless_model:
            return super().load_params(params, params_kind=params_kind)

        # This model was forced to work without weights
        self._logger.debug("Skipping params loading due to weightless mode")
        return None

    def _get_tf_graph(
        self,
        target,
        nodes=None,
        rescale_output=None,
        translate_input=None,
        custom_session=None,
        twin_mode=False,
        activation_callback=None,
        native_layers=None,
        run_numeric_in_int32=False,  # How do we use it?
    ):
        """
        Build the TF graph of the model for an emulation target.

        Args:
            target: emulation target.
            nodes: tensors the emulator graph is connected to. Either a dict that maps input
                layer names to tensors, or a single tensor that is attached to the first
                input layer of the model.
            rescale_output: whether the emulator rescales its output. Defaults to the
                numeric flag of the target.
            translate_input: whether the emulator translates its input. Defaults to the
                numeric flag of the target.
            custom_session: session whose graph is used when no nodes are given,
                None for a new graph.
            twin_mode: set to True to rename the TF graph, so this function can be called
                several times with the same session.
            activation_callback: pointer to activation function.
            native_layers: when using mixed mode; layers to run in native mode.
            run_numeric_in_int32: use int32 for numeric inference.

        Returns:
            A tuple of three items: the export object of the full graph, a generator that
            yields a ``(name, export object)`` pair for every subgraph, and the number of
            subgraphs.

        Raises:
            ModelRunnerException: ``rescale_output`` is set while the target is the mixed target.

        """
        # Bring ``nodes`` to the form of a {layer name: tensor} dict
        if nodes is None:
            nodes = {}
        elif isinstance(nodes, dict):
            nodes = {
                self._model.get_layer_by_name(layer_name).name: node for layer_name, node in nodes.items()
            }
        else:
            first_input_layer = self._model.get_input_layers()[0]
            nodes = {first_input_layer.name: nodes}

        translate_input = self._is_numeric(target) if translate_input is None else translate_input
        rescale_output = self._is_numeric(target) if rescale_output is None else rescale_output

        if twin_mode and isinstance(target, (SdkNative, SdkFPOptimized)):
            # msg: Twin mode changes the names of the tensors in the emulation graph, and
            # the quantization algorithm expects the names without this modification. If twin mode
            # is used not for quantization, then there is no problem.
            self._logger.warning(
                "Please note that you cannot collect statistics for weights calibration "
                "(quantization) when using twin mode.",
            )

        # The graph comes from the given tensors first, then from the custom session
        graph = None  # Create graph inside tf model
        if nodes:
            first_node = next(iter(nodes.values()))
            graph = first_node.graph
        elif custom_session is not None:
            graph = custom_session.graph

        should_translate_input = self._should_rescale(target, translate_input)
        executable_model, subgraph_executable_models, number_of_submodels = self._get_executable_model(
            target,
            should_translate_input,
            custom_graph=graph,
            custom_session=custom_session,
            twin_mode=twin_mode,
            activation_callback=activation_callback,
            native_layers=native_layers,
            run_numeric_in_int32=run_numeric_in_int32,
            custom_inputs=nodes,
        )

        if not twin_mode:
            self._executable_model = executable_model

        # In mixed mode, output is *always* rescaled to float32 from uint8
        if target.name == EmulationInferenceTargets.SDK_MIXED and rescale_output:
            raise ModelRunnerException(
                f"rescale_output cannot be set to True when using {EmulationInferenceTargets.SDK_MIXED} mode.",
            )

        def export_graph(model):
            # Prepare the export object of one executable model and finalize its graph
            graph_export = model.prepare_full_hailo_graph_export(
                self._last_used_params,
                self._should_rescale(target, rescale_output),
                self.requires_quantized_weights,
            )
            self._finalize_tf_graph(graph_export, target, rescale_output)
            return graph_export

        hailo_export = export_graph(executable_model)

        def exports():
            # Subgraph exports are prepared lazily, one per iteration
            for name, model in subgraph_executable_models:
                yield name, export_graph(model)

        return hailo_export, exports(), number_of_submodels

    def _get_executable_model(
        self,
        target,
        translate_input=False,
        custom_graph=None,
        custom_session=None,
        twin_mode=False,
        activation_callback=None,
        native_layers=None,
        run_numeric_in_int32=False,
        custom_inputs=None,
    ):
        """
        Create the executable model of a target, and lazily those of its subgraphs.

        Returns:
            A tuple of three items: the executable model of the full model, a generator that
            yields a ``(name, executable model)`` pair for every subgraph, and the number of
            subgraphs.

        Raises:
            ModelRunnerException: an activation callback was given for a target other than
                the native, fine tune or FP-optimized targets.
        """
        if activation_callback is not None:
            callback_supported = target.name in [
                EmulationInferenceTargets.SDK_NATIVE,
                EmulationInferenceTargets.SDK_FINE_TUNE,
                EmulationInferenceTargets.SDK_FP_OPTIMIZED,
            ]
            if not callback_supported:
                raise ModelRunnerException(
                    "Statistics collection using get_tf_graph must be executed before calling "
                    "translate_params, and from the same runner instance.",
                )

        # TODO: known limitations of twin mode:
        # 1. graph used for calibration must not be created with twin mode
        # 2. when working with client-server API, server model caching limits the user to one model per target
        executable_model_suffix = self._get_executable_model_suffix(target.name, twin_mode)

        def load_params_func():
            return self._get_executable_params(executable_model_suffix, target, native_layers)

        params = load_params_func()
        translated_consts = self._get_translated_consts(executable_model_suffix, target)
        self._last_used_params = params

        if params is None:
            consts = None
            activation_points = None
        else:
            consts = params.get_consts()
            activation_points = params.get_points_count()

        # Targets that carry no fine tune / mixed params get default ones
        if hasattr(target, "fine_tune_params"):
            fine_tune_params = target.fine_tune_params
        else:
            fine_tune_params = FineTuneParams()
        if hasattr(target, "mixed_params"):
            mixed_params = target.mixed_params
        else:
            mixed_params = SdkMixedParams()

        if self._hw_arch is None:
            is_mercury_arch = False
            is_pluto_arch = False
        else:
            is_mercury_arch = self._hw_arch.is_mercury_arch
            is_pluto_arch = self._hw_arch.is_pluto_arch

        # Clipping is read from the target only for the FP-optimized target
        is_fp_optimized = target.name == EmulationInferenceTargets.SDK_FP_OPTIMIZED
        enable_clipping = target.enable_clipping if is_fp_optimized else False

        optional_model_params = OptionalModelParams(
            translate_input=translate_input,
            model_name=executable_model_suffix,
            custom_graph=custom_graph,
            custom_session=custom_session,
            custom_inputs=custom_inputs,
            consts=consts,
            translated_consts=translated_consts,
            activation_callback=activation_callback,
            native_layers=native_layers,
            fine_tune_params=fine_tune_params,
            mixed_params=mixed_params,
            is_mercury_arch=is_mercury_arch,
            is_pluto_arch=is_pluto_arch,
            force_weightless_model=self._force_weightless_model,
            run_numeric_in_int32=run_numeric_in_int32,
            activation_points=activation_points,
            twin_mode=twin_mode,
            enable_clipping=enable_clipping,
        )

        # Drop the local reference; the params stay reachable through self._last_used_params
        del params

        model = self.target_to_model(target)

        hailo_executable_model = model_factory.create_model(model, target.name, optional_model_params)
        hailo_executable_model.set_load_params_func(load_params_func)

        subgraphs = get_subgraphs(model)

        def models():
            for subgraph_name, subgraph_model in subgraphs.items():
                # Subgraph models do not get the custom graph, inputs and session
                subgraph_params = optional_model_params._replace(
                    custom_graph=None,
                    custom_inputs=None,
                    custom_session=None,
                )
                subgraph_executable_model = model_factory.create_model(
                    subgraph_model, target.name, subgraph_params
                )
                subgraph_executable_model.set_load_params_func(load_params_func)
                yield subgraph_name, subgraph_executable_model

        return hailo_executable_model, models(), len(subgraphs)

    def _get_executable_params(self, executable_model_suffix, target, native_layers=None):
        """
        Pick the params set matching ``target`` and rename it with the executable model suffix.

        Returns:
            The selected params (renamed when a suffix is given), or None when the runner was
            asked for a weightless model.

        Raises:
            ModelRunnerException: If the selected params are missing while the runner requires them.

        """
        target_name = target.name
        is_numeric = target_name in [
            EmulationInferenceTargets.SDK_NUMERIC,
            EmulationInferenceTargets.SDK_DEBUG_PRECISE_NUMERIC,
            EmulationInferenceTargets.SDK_PARTIAL_NUMERIC,
        ]
        fp_targets = [EmulationInferenceTargets.SDK_FP_OPTIMIZED, EmulationInferenceTargets.SDK_FINE_TUNE]

        if target_name == EmulationInferenceTargets.SDK_MIXED:
            params = self._get_mixed_params(native_layers)
        elif is_numeric:
            params = self._params_translated
        elif target_name in fp_targets:
            # Prefer the optimized float params, fall back to the native ones
            params = self._params_fp_optimized or self._params
        else:
            # SDK_NATIVE and every other target run on the native params
            params = self._params

        if self._force_weightless_model:
            # The user has chosen to receive the SW graph without parameters
            return None

        if not params:
            if is_numeric and self.requires_quantized_weights:
                raise ModelRunnerException(f"Runner must have quantized params for running in {target_name} mode")
            elif self.requires_native_weights:
                raise ModelRunnerException(f"Runner must have native params for running in {target_name} mode")
        return self._get_renamed_params(params, executable_model_suffix)

    def _get_translated_consts(self, executable_model_suffix, target):
        """
        Return the consts of the translated params for the fine tune target, None for any other target.

        Raises:
            ModelRunnerException: If the target is SDK_FINE_TUNE and no translated params are loaded.

        """
        if target.name != EmulationInferenceTargets.SDK_FINE_TUNE:
            return None
        if self._params_translated is None:
            raise ModelRunnerException("You must quantize the params before running fine tune target")
        renamed = self._get_renamed_params(self._params_translated, executable_model_suffix)
        return renamed.get_consts()

    @staticmethod
    def _get_renamed_params(params, executable_model_suffix):
        if executable_model_suffix:
            params = ModelParams(params.params, executable_model_suffix=executable_model_suffix)
        return params

    def _get_executable_model_suffix(self, target_name, twin_mode):
        if not twin_mode:
            return ""
        if target_name not in self._models_history:
            self._models_history[target_name] = 0
        result = f"{target_name}_{self._models_history[target_name]}"
        self._models_history[target_name] += 1
        return result

    def _should_rescale(self, target, rescale):
        return (
            rescale
            and self._params_translated is not None
            and target.name
            not in [
                EmulationInferenceTargets.SDK_NATIVE,
                EmulationInferenceTargets.SDK_FINE_TUNE,
                EmulationInferenceTargets.SDK_FP_OPTIMIZED,
            ]
        )

    def _get_mixed_params(self, native_layers):
        result_params = {}
        if native_layers is None:
            native_layers = []
        params = self._params if self._params_fp_optimized is None else self._params_fp_optimized
        if params is None:
            raise ValueError("Missing native/optimized params for Mixed mode")
        if self._params_translated is None:
            raise ValueError("Missing numeric params for Mixed mode")
        for layer in self._model:
            if layer.name in native_layers:
                if layer.name not in params.layers:
                    continue
                for k, v in params[layer.name].items():
                    result_params[f"{layer.name}/{k}"] = v
                # This is a somewhat hacky way to ensure that layers that have more than one qp_in (EW-add and EW-mult)
                # can get the parameters from their predecessors. Note that we just add to all the layers qp_out and
                # limvals out
                self._add_quantized_params_needed_for_mixed_mode(layer.name, result_params)
            else:
                for k, v in self._params_translated[layer.name].items():
                    result_params[f"{layer.name}/{k}"] = v
        return ModelParams(result_params)

    def _add_quantized_params_needed_for_mixed_mode(self, full_layer_name, result_params):
        qp_out_key = f"{full_layer_name}/qp_out:0"
        limvals_out_key = f"{full_layer_name}/limvals_out:0"
        result_params[qp_out_key] = self._params_translated[qp_out_key]
        result_params[limvals_out_key] = self._params_translated[limvals_out_key]

    def _finalize_tf_graph(self, hailo_export, target, rescale_output):
        """
        Finish an exported TF graph: initialize variables, run the params assign ops and fill metadata.

        Returns:
            The same ``hailo_export`` object, with ``rescale_output`` set and the original names
            updated from the HN of the target's model.

        """
        session = hailo_export.session
        with session.as_default(), session.graph.as_default():
            self._handle_unset_variables(session, hailo_export.unset_variables)
            graph_ops = hailo_export.graph.get_operations()
            session.run([op for op in graph_ops if op.name.endswith("hailo_params_assign")])

        hailo_export.rescale_output = rescale_output

        model = self.target_to_model(target)
        hailo_export.update_original_names(model.to_hn(model.name, json_dump=False))

        return hailo_export

    def _handle_unset_variables(self, sess, unset_variables):
        """
        Only initializes the variables of a TensorFlow session that were not already initialized.
        These are exported via the variables export level UNSET_VARIABLES.

        Args:
            sess (tf.Session): Tensorflow session that may contain uninitialized variables.
            unset_variables (list of :obj:`tf.Variable`): Tensorflow variables that need
                initialization.

        """
        if unset_variables:
            is_initialized = sess.run([tf.compat.v1.is_variable_initialized(var) for var in unset_variables])
            unset_variables_res = [var for (var, init) in zip(unset_variables, is_initialized) if not init]
            if len(unset_variables_res) > 0:
                # msg: Using TF initializers is usually OK, for example when running NATIVE_CLIPPED
                # emulation. But in rare cases it can also indicate that variables (like weights)
                # are initialized randomly instead of with their real values.
                self._logger.debug("Found unset variables in model export, running predefined TF initializers.")
                sess.run(tf.compat.v1.variables_initializer(unset_variables_res))


class SDKBackendQuantization(SDKBackendGraph):
    def __init__(self, hn, hw_arch, work_dir=None, logger=None, alls_ignore_invalid_cmds=False):
        """
        Initialize the quantization backend on top of the graph backend.

        Args:
            hn: Forwarded to the parent backend.
            hw_arch: Forwarded to the parent backend.
            work_dir: Working directory handed to the model optimizer.
            logger: Logger to use; the default logger is used when none is given.
            alls_ignore_invalid_cmds: Forwarded to the parent backend.

        """
        super().__init__(hn, hw_arch, alls_ignore_invalid_cmds=alls_ignore_invalid_cmds)
        self._mo_commands = {}
        self._model_optimizer = ModelOptimizer(self._model, self._hw_arch, self._get_tf_graph, work_dir)
        self._logger = default_logger() if not logger else logger
        self._flow_results = []

        # Cached inference flow and the context it was built for
        self._inference_flow: SimulationInferenceModel = None
        self._inference_context = None

        self._default_compression_level = self._default_optimization_level = None
        self._calibration_data = self._calibration_data_random_max = None

        self._optimization_target = self.get_optimization_target()

    @property
    def optimization_target(self):
        return self._optimization_target

    def get_params_pre_quantization(self, *, copy_params=False):
        params = self.get_params_fp_optimized()
        if params is not None:
            new_params = copy.deepcopy(params) if copy_params else params
            return new_params
        params = self.get_params()
        if params is not None:
            params = copy.deepcopy(params)
            params.set_params_kind(ParamsKinds.FP_OPTIMIZED)
            self.load_params(params)
            new_params = copy.deepcopy(params) if copy_params else params
            return new_params
        return params

    def get_lora_adapters(self):
        return self._model.net_params.lora_adapters

    def pre_quantization_structural(self):
        params = self.get_params_pre_quantization()
        if params is None:
            return

        new_params = self._model_optimizer.sort_params(params, update_graph=False)
        if self._model_optimizer.structural_change:
            new_hn = self._model_optimizer.hn_model
            self.update_model(new_hn)
            self.load_params(new_params, ParamsKinds.FP_OPTIMIZED)
            self._model_optimizer.structural_change = False

    def update_model(self, model, override_config=False):
        """Update the model of the parent backend and then the model held by the model optimizer."""
        super().update_model(model)
        self._model_optimizer.update_model(model, override_config=override_config)

    def pre_quantization_optimization(self, clip_aware_sort=False):
        self.init_stats_collection()
        params = self.get_params_pre_quantization()
        if params is None:
            return None
        model_optimizer = self._model_optimizer
        mo_config = model_optimizer.get_config()
        if params is not None:
            # save params after_bn for debug?
            model_optimizer._hailo_np_savez(f"{model_optimizer._hn_model.name}_after_bn.npz", params)
        params = self._model_optimizer.equalize(params, self.executable_model)
        params = self._model_optimizer.sort_params(params, clip_aware_sort)
        params = self._model_optimizer.clip_weights(params, mo_config.weights_clipping)
        params = self._model_optimizer.clip_activations(params, mo_config.activation_clipping)
        self.save_and_load_pre_params(params)
        return params

    def core_quantization(self, force_results_by_layer=None, previous_statistics=None, force_params=None):
        params = self.get_params_pre_quantization() if force_params is None else force_params
        translated = self._model_optimizer.bit_reduction(
            params,
            self.executable_model,
            force_results_by_layer=force_results_by_layer,
            previous_statistics=previous_statistics,
        )
        self.save_and_load_quantize_params(translated)
        return translated

    def post_quantization_optimization(self, force_results_by_layer=None):
        params = self._model_optimizer.adaround(self.get_params_pre_quantization(), self._params_translated)

        if params is not None:
            translated = self._model_optimizer.bit_reduction(params, self.executable_model, reuse_stats=True)
            self.load_params(translated, ParamsKinds.TRANSLATED)
        start_time = time.time()
        params = self._model_optimizer.bias_correction(
            self.update_params_layer_bias,
            self.get_params_pre_quantization(),
            self._params_translated,
            self.executable_model,
        )
        end_time = time.time()
        self._model_optimizer.print_algo_time(start_time, end_time, "ibc")
        if params is not None:
            self.load_params(params, params_kind=ParamsKinds.TRANSLATED)
        self._model_optimizer.finetune(self, force_results_by_layer=force_results_by_layer)

    def build_acceleras_model(self, context: InferenceContext, lora_adapter_name=None) -> HailoModel:
        """
        Build an acceleras ``HailoModel`` for the given inference context and load its params.

        Args:
            context: SDK_NATIVE, SDK_FP_OPTIMIZED, SDK_QUANTIZED or SDK_BIT_EXACT.
            lora_adapter_name: Optional LoRA adapter; must be listed in the HN ``net_params``.

        Returns:
            The model with its params imported.

        Raises:
            ValueError: For an unknown context, or when the params were exported in an
                unsupported npz mode.
            UnsupporteLoraAdapterException: If the requested LoRA adapter is not part of the model.

        """
        nms_config = None
        if self.nms_metadata:
            nms_config = self.nms_metadata.nms_config.to_post_config()

        quantized_contexts = [InferenceContext.SDK_QUANTIZED, InferenceContext.SDK_BIT_EXACT]
        # Only the quantized contexts carry hw params and read per-layer defaults from a base model
        qnpz = None
        base_model = None
        if context == InferenceContext.SDK_NATIVE:
            hn_data = self.get_hn_native_dict()
            npz = self.get_params()
        elif context == InferenceContext.SDK_FP_OPTIMIZED:
            hn_data = self.get_hn_fp_dict()
            npz = self.get_params_pre_quantization()
            base_model = self._fp_model
        elif context in quantized_contexts:
            hn_data = self.get_hn_dict()
            npz = self.get_params_hailo_optimized()
            qnpz = self.get_params_translated()
            base_model = self._model
        else:
            raise ValueError(f"Invalid emulation context {context}")

        if lora_adapter_name is not None:
            adapters = hn_data["net_params"].get("lora_adapters", [])
            if lora_adapter_name not in adapters:
                raise UnsupporteLoraAdapterException(
                    f"LoRA adapter {lora_adapter_name} could not be found in the model. Available LoRA adapters: {adapters}."
                )

        acceleras_model = HailoModel(
            hn_data,
            nms_config=nms_config,
            optimization_target=self.optimization_target,
            lora_adapter_name=lora_adapter_name,
        )

        mode = NpzExportMode(npz.get("mode", NpzExportMode.WEIGHTS))
        if mode == NpzExportMode.ACCELERAS:
            acceleras_model.import_acceleras(npz)
            return acceleras_model
        if mode != NpzExportMode.WEIGHTS:
            raise ValueError("QNPZ params mode is not supported in accleras emulator")

        acceleras_model.import_weights(npz)
        if context in quantized_contexts:
            layers_defaults = {
                **base_model.get_per_layer_precision_config(),
                **base_model.get_per_layer_translation_config(),
            }
            layers_defaults.setdefault("precision_config", {})["target"] = self.optimization_target
            CreateMixedPrecision(
                model=acceleras_model,
                model_config=ModelOptimizationConfig(**layers_defaults),
                logger_level=0,
                for_infer=True,
            ).run()
        if qnpz is not None:
            acceleras_model.import_hw_params_from_qnpz(qnpz)

        return acceleras_model

    def apply_quantization_script(self, data=None, data_type=None):
        # needs the initialized dataset
        return self._apply_quantization_script_with_flavor()

    def _apply_quantization_script_with_flavor(self, flavor_config=None, mo_flavor=None):
        """
        Parse the model script into model optimization commands and apply them to the hn.

        Args:
            flavor_config: Forwarded to the commands verification.
            mo_flavor: Accepted but not used by this implementation.

        Returns
            the model optimization commands exported from the script parser

        """
        script_parser = self._script_parser

        # Expand the glob syntax and validate the commands
        script_parser.update_model_optimization_commands()

        # Verify the exported commands and clean invalid glob commands
        verified_config = verify_commands(self._model, self.model_optimization_commands, flavor_config)

        # NOTE: reloading the verified commands into the script parser is currently disabled:
        # self._script_parser.reload_mo_commands(config, exclude_defaults=False)

        # Apply only quantization groups from the config to the model.
        # Required for params sorter. should be removed in the near future
        apply_quantization_config_to_hn(self._model, verified_config)

        # Export the commands from the script once more, after they were applied
        return self.model_optimization_commands

    def _get_parameters_count(self):
        """
        Calculate the number of parameters of the model

        Returns
            number of parameters

        """
        params = self.get_params()
        if params is not None:
            return sum([np.array(x).size for x in list(params.values())])
        else:
            return 0

    def set_default_optimization_flavor(self, compression_level=0, optimization_level=0):
        self._default_compression_level = compression_level
        self._default_optimization_level = optimization_level

    def setup_quantization(self, data_continer: DatasetContianer, work_dir=None, config=None):
        if work_dir is not None:
            self._model_optimizer.work_dir = work_dir

        if config is None:
            config = self.apply_quantization_script(data_continer)
        else:
            self._logger.warning("Ignoring quantization model script commands, using data from API")

        self._model_optimizer.load_config(config)

    def run_layer_analysis_tool(self, data, data_count, batch_size, analyze_mode, **kwargs):
        """
        Run the quantization analyzer on the quantized model against a float reference model.

        Args:
            data: Unbatched dataset handed to the analyzer.
            data_count: Dataset size written to the checker config.
            batch_size: Batch size written to the checker config.
            analyze_mode: Analysis mode; ``AnalysisMode.advanced`` is used when None.
            **kwargs: ``use_optimize_model`` selects the reference model: when truthy it is the
                quantized model switched to native, otherwise the fp-optimized model.

        Returns:
            The analysis results of the analyzer.

        """
        mode = AnalysisMode.advanced if analyze_mode is None else AnalysisMode(analyze_mode)
        checker_cfg = CheckerConfig(
            policy=FeaturePolicy.enabled,
            dataset_size=data_count,
            batch_size=batch_size,
            analyze_mode=mode,
        )
        # Work on a copy of the model config so the given arguments overwrite the configuration
        # without changing the alls
        config = self._model_optimizer.get_config().copy(update={"checker_cfg": checker_cfg})

        model = self.build_acceleras_model(InferenceContext.SDK_QUANTIZED)
        use_optimized = kwargs.get("use_optimize_model", False)
        reference_context = InferenceContext.SDK_QUANTIZED if use_optimized else InferenceContext.SDK_FP_OPTIMIZED
        native_model = self.build_acceleras_model(reference_context)
        if use_optimized:
            native_model.set_native()

        analyzer = HailoQuantAnalyzer(
            model=model,
            model_config=config,
            unbatched_data_set=data,
            native_model=native_model,
        )
        analyzer.run()

        statistics = analyzer.get_statistics()
        if statistics is not None:
            # Merge the new statistics into the ones that are already loaded
            merged_statistics = self.get_params_statistics()
            if merged_statistics is None:
                merged_statistics = {}
            merged_statistics.update(statistics)
            self.load_params(merged_statistics, ParamsKinds.STATISTICS)

        return analyzer.analysis_results

    def full_quantization(
        self,
        data_continer: DatasetContianer,
        *,
        work_dir=None,
        checkpoint_info: CheckpointInfo = None,
    ) -> CheckpointInfo:
        """
        Run the full quantization flow

        Args:
            data_continer: Dataset container forwarded to the setup and to the acceleras run.
            work_dir: Optional work dir forwarded to the setup and to the acceleras run.
            checkpoint_info: Checkpoint to run with; a fresh ``CheckpointInfo()`` is created for
                every call when omitted.

        Returns:
            The checkpoint info produced by the acceleras run.

        """
        # The checkpoint is mutated further down the flow (``run_until`` may be overwritten), so
        # a default instance must never be shared between calls.
        if checkpoint_info is None:
            checkpoint_info = CheckpointInfo()

        self._logger.important("Starting Model Optimization")
        self.update_fp_model(self._model)
        self._messages.nv12_optimization_message()
        self.setup_quantization(data_continer, work_dir=work_dir)
        self.pre_quantization_structural()
        new_checkpoint_info = self._full_acceleras_run(
            data_continer, work_dir=work_dir, checkpoint_info=checkpoint_info
        )

        if new_checkpoint_info.quantization_done:
            self._logger.verbose("Core and post Quantization is done with Acceleras")
            self._finalize_quantization()
            self._logger.important("Model Optimization is done")

        return new_checkpoint_info

    def lora_quantization(
        self,
        adapter_name: str,
        data_continer: DatasetContianer,
        *,
        work_dir=None,
        checkpoint_info: CheckpointInfo = CheckpointInfo(),
    ) -> CheckpointInfo:
        """
        Quantization flow of a single LoRA adapter - currently disabled.

        Raises:
            NotImplementedError: Always, until LoRA optimization is supported.

        """
        # TODO: SDK-57131 generalize optimization_flow to support LoRA optimization. Then remove this exception.
        raise NotImplementedError("LoRA optimization is currently WIP and not supported at the moment.")

        # NOTE: everything below is unreachable while the exception above is in place; it is kept
        # as the intended flow for when LoRA optimization is enabled.
        adapters = self.get_lora_adapters()
        if adapter_name not in adapters:
            raise UnsupporteLoraAdapterException(
                f"LoRA adapter {adapter_name} could not be found in the model. Available LoRA adapters: {adapters}."
            )
        self._logger.important(f"Starting LoRA {adapter_name} Optimization")

        self.setup_quantization(data_continer, work_dir=work_dir)
        new_checkpoint_info = self._full_acceleras_run(
            data_continer, adapter_name=adapter_name, work_dir=work_dir, checkpoint_info=checkpoint_info
        )
        if new_checkpoint_info.quantization_done:
            self._logger.verbose("Core and post Quantization is done with Acceleras")
            self._finalize_quantization()
            self._logger.important(f"LoRA {adapter_name} Optimization is done")
        return new_checkpoint_info

    def save_and_load_pre_params(self, params):
        """Load ``params`` as FP_OPTIMIZED and, when they are not None, save them as ``<model name>_pre.npz``."""
        self.load_params(params, ParamsKinds.FP_OPTIMIZED)
        if params is None:
            return
        optimizer = self._model_optimizer
        optimizer._hailo_np_savez(f"{optimizer._hn_model.name}_pre.npz", params)

    def save_and_load_params_hailo_optimized(self, params):
        """Load ``params`` as HAILO_OPTIMIZED and, when they are not None, save them as ``<model name>_optimized.npz``."""
        self.load_params(params, ParamsKinds.HAILO_OPTIMIZED)
        if params is None:
            return
        optimizer = self._model_optimizer
        optimizer._hailo_np_savez(f"{optimizer._hn_model.name}_optimized.npz", params)

    def save_and_load_quantize_params(self, translated):
        """
        Load the translated params and save the quantization artifacts.

        When the optimizer has a work dir, the HN of its model is written there as
        ``<model name>.q.hn``. The params are always passed to the optimizer's npz saver as
        ``<model name>_quant.npz``.
        """
        self.load_params(translated, ParamsKinds.TRANSLATED)
        optimizer = self._model_optimizer
        work_dir = optimizer._work_dir

        if work_dir is not None:
            hn_model = optimizer._hn_model
            hn_str = hn_model.to_hn(hn_model.name)
            hn_path = os.path.join(work_dir, f"{hn_model.name}.q.hn")
            with open(hn_path, "w") as hn_file:
                hn_file.write(hn_str)

        optimizer._hailo_np_savez(f"{optimizer._hn_model.name}_quant.npz", translated)

    @staticmethod
    def is_supported_in_acceleras(translation_config):
        """
        Check whether a translation config is supported in acceleras.

        A config is supported when every field equals the default of ``LayerTranslationConfig``;
        ``null_channels_cutoff_factor`` is the only field that is allowed to differ.

        Args:
            translation_config: config object exposing ``raw_dict``

        Returns: True if it is

        """
        defaults = LayerTranslationConfig.get_default().raw_dict(False, False)
        candidate = translation_config.raw_dict(False, False)
        has_override = any(
            defaults[field] != candidate[field]
            for field in defaults
            if field != "null_channels_cutoff_factor"
        )
        return not has_override

    def hw_inference(self, infer_info: InferInfo):
        """Build and compile a HW inference flow from the graph export and run it on the given data."""
        context_info = infer_info.context_info
        self._inference_context = context_info.infer_context
        flow = HWInferenceModel(context_info.graph_export)
        self._inference_flow = flow
        flow.compile(run_eagerly=infer_info.run_eagerly)
        return flow.run(infer_info.data, infer_info.batch_size, infer_info.data_count)

    def get_emulation_model(self, context_info: InternalContextInfo, trainable: bool):
        """
        Build an emulation model for the given context.

        Args:
            context_info: Supplies the inference context, the LoRA adapter name and the flow commands.
            trainable: Selects ``SimulationTrainingModel`` over ``SimulationInferenceModel``.

        Returns:
            The model, set to quantized for SDK_QUANTIZED (with the flow commands applied when
            present) and to native for any other context.

        """
        acceleras_model = self.build_acceleras_model(
            context=context_info.infer_context,
            lora_adapter_name=context_info.lora_adapter_name,
        )

        model_cls = SimulationTrainingModel if trainable else SimulationInferenceModel
        model = model_cls(acceleras_model)

        if context_info.infer_context == InferenceContext.SDK_QUANTIZED:
            model.set_quantized()
            flow_commands = context_info.flow_commands
            if flow_commands:
                model.custom_infer_config(flow_commands)
            return model

        model.set_native()
        return model

    @staticmethod
    def get_hw_model(hailo_export, run_eagerly=False):
        """Wrap ``hailo_export`` in a compiled ``HWInferenceModel``."""
        hw_model = HWInferenceModel(hailo_export)
        hw_model.compile(run_eagerly=run_eagerly)
        return hw_model

    def set_emulation_model(self, model, slim_mode=False):
        """
        Load the params exported from an emulation model back into the backend.

        Args:
            model: Emulation model wrapping the acceleras model in ``_model``.
            slim_mode: When True only the hw params are exported and loaded; otherwise the
                acceleras params are loaded as HAILO_OPTIMIZED as well.

        Raises:
            BackendRuntimeException: If exporting the params from the model fails.

        """
        acceleras_model = model._model
        full_export = not slim_mode
        try:
            fp_after_optimization = acceleras_model.export_acceleras() if full_export else None
            params_exported = acceleras_model.export_hw_params()
        except Exception as e:
            raise BackendRuntimeException(f"Failed to set model: {e}") from e

        if full_export:
            self.load_params(fp_after_optimization, ParamsKinds.HAILO_OPTIMIZED)
        self.load_params(params_exported, ParamsKinds.TRANSLATED)
        self._inference_flow = None  # remove old flow

    def acceleras_inference(self, infer_info: InferInfo):
        """
        Run inference with the acceleras emulator, reusing the cached flow when possible.

        The cached flow is reused when it was built for the same inference context and the same
        distribution strategy; otherwise a new model is built inside the distributed context.

        Args:
            infer_info: Supplies the context info (inference context, LoRA adapter, gpu policy,
                flow commands), the data, the batch size and the data count.

        Returns:
            The result of running the inference flow.

        """
        context = infer_info.context_info.infer_context
        lora_adapter_name = infer_info.context_info.lora_adapter_name

        gpu_info = get_gpu_availability_mode()
        # TODO: Changing gpu availability outside will be change https://hailotech.atlassian.net/browse/SDK-49012
        gpu_info.gpu_availability = GPUAvailabilityMode.NOT_IN_USE
        gpu_strategy = get_strategy(
            gpu_policy=infer_info.context_info.gpu_policy,
            default_gpu_policy=DistributionStrategy.DATA_P,
            supported_gpu_info={DistributionStrategy.DATA_P, DistributionStrategy.MODEL_P},
            gpu_info=gpu_info,
        )
        if gpu_info.num_gpus > 0:
            self._logger.info(f"Using {gpu_info.num_gpus} GPU for inference")

        if (
            self._inference_flow is not None
            and self._inference_context == context
            and self._inference_flow.dist_info.dist_strategy == gpu_strategy.dist_strategy
        ):
            self._logger.verbose("Reusing model")
        else:
            with gpu_distributed_context(gpu_strategy) as context_info:
                model = self.build_acceleras_model(context, lora_adapter_name=lora_adapter_name)
                model.dist_info = context_info
                self._inference_flow = SimulationInferenceModel(model)
                self._inference_flow.dist_info = context_info
                self._inference_context = context
                if context in [InferenceContext.SDK_QUANTIZED, InferenceContext.SDK_BIT_EXACT]:
                    self._inference_flow.set_quantized()
                    if context == InferenceContext.SDK_BIT_EXACT:
                        self._inference_flow.set_native(False)
                        self._inference_flow.set_bit_exact(True)
                    else:
                        self._inference_flow.set_bit_exact(False)

                    if infer_info.context_info.flow_commands:
                        self._inference_flow.custom_infer_config(infer_info.context_info.flow_commands)
                else:
                    self._inference_flow.set_native(True)

        with gpu_distributed_context(gpu_strategy):
            batch_size = infer_info.batch_size
            # The strategy enum lives on ``dist_strategy``; comparing the strategy object itself
            # with the enum member never matches, which left the batch size unscaled.
            is_data_parallel = gpu_strategy.dist_strategy == DistributionStrategy.DATA_P
            if is_data_parallel and batch_size is not None:
                # Never scale the batch down to zero when no GPU is reported
                batch_size = batch_size * max(gpu_info.num_gpus, 1)
            return self._inference_flow.run(infer_info.data, batch_size, infer_info.data_count)

    def _checkpoint_update(
        self, checkpoint_info: CheckpointInfo, mo_config: ModelOptimizationConfig
    ) -> Tuple[CheckpointInfo, ModelOptimizationConfig]:
        """
        Keep the DEEQUALIZE stop of the checkpoint and the deequalize policy of the config in sync.

        A checkpoint that stops at DEEQUALIZE enables deequalize in the config; otherwise a
        config with deequalize enabled makes the checkpoint stop at DEEQUALIZE. Both objects are
        modified in place and returned.
        """
        if checkpoint_info.run_until == SupportedStops.DEEQUALIZE:
            mo_config.globals.deequalize = FeaturePolicy.enabled
        elif mo_config.globals.deequalize == FeaturePolicy.enabled:
            checkpoint_info.run_until = SupportedStops.DEEQUALIZE
        return checkpoint_info, mo_config

    def _full_acceleras_run(
        self,
        data_continer: DatasetContianer,
        *,
        adapter_name=None,
        work_dir=None,
        checkpoint_info: CheckpointInfo = None,
    ) -> CheckpointInfo:
        """
        Parse the model script into a config, run the optimization flow and load its results.

        Args:
            data_continer: Dataset container handed to the script parser and to the flow.
            adapter_name: Optional LoRA adapter name; when given, the flow results are merged
                into the already loaded optimized/translated params instead of replacing them.
            work_dir: Work dir of the run; the work dir of the model optimizer is used when None.
            checkpoint_info: Checkpoint to run with; a fresh ``CheckpointInfo()`` is created for
                every call when omitted.

        Returns:
            The checkpoint info returned by the optimization flow runner.

        """
        # ``_checkpoint_update`` may overwrite ``run_until`` on the checkpoint, so a default
        # instance must never be shared between calls.
        if checkpoint_info is None:
            checkpoint_info = CheckpointInfo()

        # Common Variables
        work_dir = self._model_optimizer.work_dir if work_dir is None else work_dir  # this need to be removed
        num_of_parameters = self._get_parameters_count()

        # Saving model config
        model_config = self._model_optimizer.get_config()
        self._model_optimizer._config.save_cfg(
            work_dir,
            model_config,
            "loaded_config.yaml",
        )
        # Creating Mo Config
        layers_defaults = {
            **self._model.get_per_layer_precision_config(),
            **self._model.get_per_layer_translation_config(),
        }

        parser = MOScriptParser(
            self._script_parser.original_script,
            layers_defaults,
            data_continer,
            num_of_parameters,
            self._logger,
            self.optimization_target,
        )
        mo_config = parser.run()
        self._validate_commands(mo_config)
        self._messages.log_optimization_flavor_comments(parser.results)
        checkpoint_info, mo_config = self._checkpoint_update(checkpoint_info, mo_config)

        # Optimization Flow
        optimization_flow = OptimizationFlow(
            self.get_hn_dict(),
            self.get_params_pre_quantization(copy_params=True),
            mo_config,
            data_continer,
            self.optimization_target,
            logger=self._logger,
            work_dir=work_dir,
            nms_config=self._nms_metadata.nms_config.to_post_config() if self._nms_metadata else None,
            params_statistics=self.get_params_statistics(),
            adapter_name=adapter_name,
        )

        new_checkpoint_info = self._optimization_flow_runner(optimization_flow, checkpoint_info)

        # Saving Params and HN
        if new_checkpoint_info.quantization_done:
            if adapter_name is None:
                acceleras_params = {}
                quant_params = {}
            else:
                # An adapter run extends the params that are already loaded
                acceleras_params = self.get_params_hailo_optimized()
                quant_params = self.get_params_translated()
            acceleras_params.update(optimization_flow.get_acceleras_params())
            quant_params.update(optimization_flow.get_quant_params())
            self.save_and_load_params_hailo_optimized(acceleras_params)
            self.save_and_load_quantize_params(quant_params)
            self.load_params(optimization_flow.params_statistics, ParamsKinds.STATISTICS)
            self.update_from_hn(optimization_flow.get_hn())
            self.modifications_meta_data.tracker.update(optimization_flow.modifications_meta_data)

            # Updating metadata
            self.mo_flavor = parser.results
            self._flow_results = optimization_flow.flow_results

        # Saving Params and HN after_equalization

        elif checkpoint_info.run_until == SupportedStops.DEEQUALIZE:
            fp_params = self._update_deequalize_fp_params(optimization_flow)
            self.load_params(fp_params, ParamsKinds.FP_OPTIMIZED)
            self._logger.important("Model equalize is done")

        return new_checkpoint_info

    def get_optimization_target(self) -> OptimizationTarget:
        """
        Map the configured HW architecture to its optimization target.

        Returns:
            The ``OptimizationTarget`` member matching ``self.hw_arch.name``.

        Raises:
            BackendRuntimeException: if the arch name belongs to none of the known arch groups.

        """
        hw_arch = self.hw_arch
        arch_name = hw_arch.name

        if arch_name in hw_arch.SAGE_B0_ARCHS:
            return OptimizationTarget.SAGE
        # A name listed in PLUTO_ARCHS never resolves to MERCURY, even if MERCURY_ARCHS also lists it.
        if arch_name in hw_arch.PLUTO_ARCHS:
            return OptimizationTarget.PLUTO
        if arch_name in hw_arch.MERCURY_ARCHS:
            return OptimizationTarget.MERCURY
        if arch_name in hw_arch.MARS_ARCHS:
            return OptimizationTarget.MARS
        raise BackendRuntimeException(f"Unsupported HW arch {arch_name}")

    def get_flow_memento(self) -> FlowCheckPoint:
        """Return the optimization flow memento currently stored on this backend."""
        stored_memento = self._optimization_flow_memento
        return stored_memento

    def set_flow_memento(self, memento: FlowCheckPoint) -> None:
        """
        Store the optimization flow memento on this backend.

        Args:
            memento: the flow checkpoint to keep; it replaces any previously stored one.

        """
        self._optimization_flow_memento = memento

    def _finalize_quantization(self):
        self._model_optimizer.update_model(self.model, override_config=True)
        self._model_optimizer.save_config()

    def _update_deequalize_fp_params(self, optimization_flow):
        fp_params = self.get_params_fp_optimized().params
        deequalize_params = optimization_flow.get_deequalize_params()

        fp_params.update(deequalize_params)
        return fp_params

    def init_stats_collection(self):
        """Start statistics collection in the model optimizer using the HN dict and pre-quantization params."""
        self._model_optimizer.init_stats_collection(
            self.get_hn_dict(),
            self.get_params_pre_quantization(),
        )

    @property
    def model_optimization_commands(self):
        """Model optimization commands, as exported by the model script parser."""
        script_parser = self._script_parser
        return script_parser.export_model_optimization_commands()

    @property
    def model_modifications_commands(self):
        """Model modification commands, as exported by the model script parser."""
        script_parser = self._script_parser
        return script_parser.export_model_modifications_commands()

    def update_params_layer_bias(self, bias_diff, layer):
        """
        Apply a bias difference to the translated params of a single layer.

        Args:
            bias_diff: bias delta forwarded to ``quantize_model.update_translated_params_layer_bias``.
            layer: object whose ``name`` holds at least two "/"-separated parts; only the first two
                parts are used as the layer name.

        Returns:
            The translated params, as returned by ``get_params_translated`` after the update.

        """
        name_parts = layer.name.split("/")
        layer_name = f"{name_parts[0]}/{name_parts[1]}"

        # The conv inference item is optional: it is only looked up when an executable model
        # exposing ``conv_layers_inference`` is attached to this backend.
        executable_model = getattr(self, "_executable_model", None)
        if hasattr(executable_model, "conv_layers_inference"):
            conv_layer_inference_item = executable_model.conv_layers_inference[layer_name]
        else:
            conv_layer_inference_item = None

        updated_bias_params = quantize_model.update_translated_params_layer_bias(
            self._params_translated,
            layer_name,
            bias_diff,
            conv_layer_inference_item,
        )
        self._params_translated.update(updated_bias_params)

        return self.get_params_translated()

    def translate_params_by_inference_results(
        self,
        inference_results,
        previous_statistics=None,
        debug_precise_mode=False,
        max_elementwise_feed_repeat=MAX_NUM_REPEATS_ELTWISE,
    ):
        """
        Quantize the model params from per-layer inference results (deprecated entry point).

        Args:
            inference_results: a dict from the layer name to the inference results of that layer (from the user).
                Can be created by running get_results_by_layer on the inference results, and the output requests.
            previous_statistics: previous layer statistics used for scale matching.
            debug_precise_mode: run quantization in a higher precision for emulator code debug.
            max_elementwise_feed_repeat: Max value of elementwise feed repeat, used for calculating the
                quantized representation of biases and elementwise-add.

        Returns:
            rescaled params + scales, ready to be loaded by the model simulator.

        Raises:
            ModelRunnerException: if no executable model is attached to this runner instance.

        """
        # TODO: https://hailotech.atlassian.net/browse/SDK-51152
        self._logger.deprecation_warning(
            "translate_params_by_inference_results will be removed in the near future, please use "
            "ClientRunner.optimize instead",
            DeprecationVersion.JAN2022,
        )

        self.apply_quantization_script()

        # The conv layers inference lives on the executable model, so one must already exist.
        executable_model = self._executable_model
        if executable_model is None:
            raise ModelRunnerException(
                "Statistics collection using get_tf_graph must be executed before calling "
                "translate_params, and from the same runner instance.",
            )
        executable_model.get_finalized_conv_layers_inference(inference_results)

        # Quantization starts from the pre-quantization params and the finalized conv layers inference.
        pre_quantization_params = self.get_params_pre_quantization()
        collected_statistics = quantize_model.quantize_model(
            self._hw_arch,
            pre_quantization_params,
            self._executable_model.conv_layers_inference,
            self._model,
            previous_statistics,
            debug_precise_mode,
            is_apu_2s_complement=True,
            max_elementwise_feed_repeat=max_elementwise_feed_repeat,
        )

        # Keep the quantized result on the backend and expose its raw params to the caller.
        self._params_translated = set_quantized_params.set_quantized_params(
            pre_quantization_params,
            collected_statistics,
        )
        return self._params_translated.params

    def _validate_commands(self, mo_commands: ModelOptimizationConfig):
        """
        Reject unsupported combinations of model script commands and optimization config.

        Args:
            mo_commands: the model optimization config checked against the model script and model.

        Raises:
            BackendRuntimeException: if a set-seed command is combined with an enabled finetune policy,
                or if output encoding vector is enabled while the model has post-process layers.

        """
        seed_requested = any(isinstance(cmd, SetSeedCommand) for cmd in self.model_modifications_commands)
        if seed_requested and mo_commands.finetune.policy == FinetunePolicy.enabled:
            raise BackendRuntimeException(
                "Can't reproduce model optimization results while 'finetune' algorithm is "
                "applied. Please use a different optimization level or remove 'set_seed' "
                "from the model script",
            )

        postprocess_names = [node.name for node in self._model.nodes if node.op == LayerType.postprocess]
        if postprocess_names and mo_commands.globals.output_encoding_vector == FeaturePolicy.enabled:
            raise BackendRuntimeException(
                "Output encoding vector is not supported on models that use "
                "HailoRT-Post Processing capabilities, which are required for supporting "
                f"the layers {postprocess_names}. Please either disable output vector encoding vector, "
                "or attempt to remove those layers from the model.",
            )

    def _apply_model_modification_commands(self, model, params, update_model_and_params):
        """
        Apply every model modification command of the model script to the model and params.

        Args:
            model: the model the commands are applied to.
            params: the params passed through the commands together with the model.
            update_model_and_params: when truthy, record each command's meta data on the backend and
                let the final reload update the backend's model and params.

        Returns:
            The ``(model, params)`` pair returned by ``_reload_model_and_params``.

        """
        commands = list(self.model_modifications_commands)
        if commands:
            self._validate_modification_commands_order(commands, model)

        nms_added = False
        for command in commands:
            # Layer names are rebuilt for every command because ``apply`` returns a new model value.
            command.validate_command([layer.name for layer in model])
            self._messages.default_yuv_to_rgb_warning(command)

            model, params = command.apply(model, params, hw_consts=self.hw_arch.consts)
            if update_model_and_params:
                for layer_name in command.meta_data:
                    self._add_layer_to_modifications_meta_data(layer_name, command.meta_data[layer_name], model)
            if isinstance(command, NMSPostprocessCommand):
                self._nms_metadata = command.export_nms_metadata()
                nms_added = True

        if not nms_added:
            # in case a config file was generated while parsing but nms is not added.
            self._nms_metadata = None

        return self._reload_model_and_params(model, params, update_model_and_params)

    def _validate_modification_commands_order(self, commands, model):
        """
        Verify that no transpose / input-resize command follows the last input conversion command.

        The check is skipped when there are no conversion commands, when the last command is a
        conversion command, or when every conversion command is a color conversion.

        Args:
            commands: the ordered list of model modification commands.
            model: the model, used to resolve the names of its input layers.

        Raises:
            BackendRuntimeException: if a transpose command, or a resize of an input layer, appears
                after the last conversion command.

        """
        conversion_types = [FormatConversionCommand, InputConversionCommand]
        # Exact type match: subclasses of the conversion commands are not counted here.
        conversions = [(idx, cmd) for idx, cmd in enumerate(commands) if type(cmd) in conversion_types]
        if not conversions:
            return

        last_conversion_idx = conversions[-1][0]
        if last_conversion_idx == len(commands) - 1:
            return

        if all(cmd.conversion_type in ColorConversionType for _, cmd in conversions):
            return

        trailing_commands = commands[last_conversion_idx + 1 :]
        transposes = [cmd for cmd in trailing_commands if isinstance(cmd, TransposeCommand)]
        resize_names = [
            name for cmd in trailing_commands if isinstance(cmd, ResizeCommand) for name in cmd.get_layers()
        ]
        input_names = [layer.name for layer in model.get_input_layers()]

        # resize_name is None == resizing input layer
        resizes_input = any(name is None or name in input_names for name in resize_names)
        if transposes or resizes_input:
            raise BackendRuntimeException(
                f"Can't add {commands[-1].function_name} layer before an input conversion, "
                "as it modifies the tensor format. Please verify the model script commands "
                "are in correct order (no transpose or resize layers before input "
                "conversion layers)."
            )

    def _add_layer_to_modifications_meta_data(self, lname, config, model):
        if lname in self.modifications_meta_data.inputs:
            self.modifications_meta_data.inputs[lname].insert(0, config)
        elif lname in self.modifications_meta_data.outputs:
            self.modifications_meta_data.outputs[lname].insert(0, config)
        elif lname in [x.name for x in model.get_input_layers()]:
            self.modifications_meta_data.inputs[lname] = [config]
        else:
            self.modifications_meta_data.outputs[lname] = [config]

    @property
    def nms_metadata(self):
        """NMS metadata held by the backend (None when no NMS metadata is set)."""
        return self._nms_metadata

    @nms_metadata.setter
    def nms_metadata(self, value):
        """Replace the NMS metadata held by the backend."""
        self._nms_metadata = value

    @property
    def har(self):
        """The HAR object attached to this backend."""
        return self._har

    @har.setter
    def har(self, value):
        """Attach a HAR object to this backend."""
        self._har = value

    @property
    def mo_flavor(self):
        """The model optimization flavor stored on this backend."""
        return self._mo_flavor

    @mo_flavor.setter
    def mo_flavor(self, value):
        """Store the model optimization flavor on this backend."""
        self._mo_flavor = value

    @property
    def flavor_config(self):
        """The flavor config stored on this backend."""
        return self._flavor_config

    @flavor_config.setter
    def flavor_config(self, value):
        """Store the flavor config on this backend."""
        self._flavor_config = value

    @property
    def calibration_data(self):
        """The calibration data container stored on this backend."""
        return self._calibration_data

    @calibration_data.setter
    def calibration_data(self, value):
        """Store the calibration data container on this backend."""
        self._calibration_data = value

    @property
    def calibration_data_random_max(self):
        """The max value passed to the random calibration dataset generator."""
        return self._calibration_data_random_max

    @calibration_data_random_max.setter
    def calibration_data_random_max(self, value):
        """Store the max value passed to the random calibration dataset generator."""
        self._calibration_data_random_max = value

    def optimize_full_precision(
        self, update_model_and_params=True, data_continer: DatasetContianer = None
    ):  # need to fix this
        """
        Run the full-precision optimization stage: model modification commands, then the post fuser.

        Args:
            update_model_and_params: when truthy, the backend's own model and params are updated
                with the results; otherwise only the returned values reflect them.
            data_continer: calibration data container; when None a fallback is used (see note below).

        Returns:
            The ``(model, params)`` pair produced by the post fuser stage.

        """
        # NOTE(review): the fallback is the DatasetContianer class itself, not an instance. It only
        # works if the class exposes a ``data`` attribute that is None - confirm and consider
        # constructing an instance instead.
        if data_continer is None:
            data_continer = DatasetContianer
        self._logger.debug("Full-precision optimization stage starting.")

        # Work on a fresh copy of the backend model, rebuilt from its HN representation.
        model = HailoNN.from_parsed_hn(self._model.to_hn(self._model.name, json_dump=False))

        # Prefer existing FP-optimized params; otherwise start from the native params (or empty
        # params) and mark them as FP-optimized.
        params = self.get_params_fp_optimized()
        if not params:
            params = self.get_params() or ModelParams({})
            params.set_params_kind(ParamsKinds.FP_OPTIMIZED)

        model, params = self._apply_model_modification_commands(model, params, update_model_and_params)
        # Called for its side effect of settling the backend's calibration data.
        data_continer = self.update_or_create_calib_data(data_continer)
        model, params = self._run_post_fuser(model, params, update_model_and_params)

        # flow = optimization_flow()
        # flow.run(until="step0")
        # model = flow.model
        # params =flow.native()

        # memento = flow.call_history

        # flow.run(memento)
        # TODO: SDK-48237
        # 1. init fp optimization flow object
        # 2. setup acceleras with calib data
        # 3. run structural algos
        # TODO need to get params , hn
        # 4. update fp optimized model and weights

        self._logger.debug("Full-precision optimization stage is done.")

        return model, params

    def update_or_create_calib_data(self, data_continer: DatasetContianer) -> DatasetContianer:
        """
        Resolve the calibration data container to use and keep it on the backend.

        Resolution order: the given container if it carries data, then the container already stored
        on the backend, and finally a new container wrapping a random calibration dataset.

        Args:
            data_continer: candidate container; used only when its ``data`` is not None.

        Returns:
            The container that was selected or created.

        """
        if data_continer.data is not None:
            self.calibration_data = data_continer
            return data_continer

        stored_container = self.calibration_data
        if stored_container is not None:
            return stored_container

        self._logger.debug("Use random calibration data for full precision optimization")
        # A falsy max (None or 0) is replaced by 1 before the random dataset is generated.
        self.calibration_data_random_max = self.calibration_data_random_max or 1
        random_container = DatasetContianer(
            get_random_calib_dataset(
                self._model,
                self._calibration_data_random_max,
            )
        )
        self.calibration_data = random_container
        return random_container

    def _reload_model_and_params(
        self, model, params, update_model_and_params, params_statistics=None, modifications_meta_data=None
    ):
        """
        Rebuild the model and params as fresh objects and optionally load them into the backend.

        Args:
            model: model to rebuild from its HN representation (under ``self.model_name``).
            params: params whose ``params`` mapping is wrapped in a new ``ModelParams``.
            update_model_and_params: when truthy, the backend's model, FP model and params are
                updated with the rebuilt objects.
            params_statistics: optional statistics params, loaded only on update.
            modifications_meta_data: optional meta data merged into the tracker, only on update.

        Returns:
            The rebuilt ``(model, params)`` pair.

        """
        reloaded_model = HailoNN.from_parsed_hn(model.to_hn(self.model_name, json_dump=False))
        reloaded_params = ModelParams(params.params)

        if not update_model_and_params:
            return reloaded_model, reloaded_params

        self.update_model(reloaded_model, override_config=True)
        self.update_fp_model(reloaded_model)
        self.load_params(reloaded_params)
        if params_statistics:
            self.load_params(params_statistics)
        if modifications_meta_data:
            self.modifications_meta_data.tracker.update(modifications_meta_data)

        return reloaded_model, reloaded_params

    def _run_post_fuser(self, fp_optimized_model, fp_optimized_params, update_model_and_params):
        """
        Run the post fuser on the FP-optimized model and reload its results.

        Args:
            fp_optimized_model: the model handed to the post fuser.
            fp_optimized_params: the params handed to the post fuser.
            update_model_and_params: forwarded to ``_reload_model_and_params``.

        Returns:
            The ``(model, params)`` pair returned by ``_reload_model_and_params``.

        """
        self._update_misc_params_to_model(fp_optimized_model)

        # The post fuser config is built from the model script's post-fuser commands.
        fuser_config = verify_commands(
            self._model,
            self._script_parser.export_post_fuser_commands(),
            pre_quantization_mode=True,
        )

        # Fall back to empty statistics params when the backend holds none.
        statistics = self.get_params_statistics()
        if not statistics:
            statistics = ModelParams({})
            statistics.set_params_kind(ParamsKinds.STATISTICS)

        fuser = HailoNNPostFuser(fp_optimized_model, fp_optimized_params, statistics, fuser_config, self._hw_arch)
        fuser.run()

        return self._reload_model_and_params(
            fuser.model,
            fuser.params,
            update_model_and_params,
            fuser.params_statistics,
            fuser.modifications_meta_data,
        )

    def _update_misc_params_to_model(self, fp_optimized_model):
        # TODO: this behavior uses old assignment approach. as long as we don't add any validation, it should be fine.
        # TODO: Some would call this code a hack, and I agree
        quantization_param_commands = self._script_parser.export_quantization_param_commands()

        # TODO: https://hailotech.atlassian.net/browse/SDK-34802
        keys = ["null_channels_cutoff_factor"]
        layers = [layer.name for layer in fp_optimized_model]
        for cmd in quantization_param_commands:
            quantization_param = {key: cmd.quantization_params[key] for key in keys if key in cmd.quantization_params}
            if not quantization_param:
                continue
            cmd.expand_glob(layers, fp_optimized_model.net_params.net_scopes)
            cmd.validate_command(layers)
            for layer in cmd.input_layers:
                for key in quantization_param:
                    setattr(
                        fp_optimized_model.get_layer_by_name(layer).translation_config,
                        key,
                        quantization_param[key],
                    )

    def is_modified_model(self):
        """
        Tell whether the model was modified by input conversions or model script commands.

        Returns:
            True if any input layer has a conversion type, or if any model modification command
            other than a transpose exists; False otherwise.

        """
        if any(layer.conversion_type is not None for layer in self._model.get_input_layers()):
            return True

        return any(
            command.function_name != SupportedCommands.TRANSPOSE for command in self.model_modifications_commands
        )


class SdkBackendCompilation(SDKBackendQuantization):
    def __init__(self, hn, hw_arch, alls_ignore_invalid_cmds=False):
        """
        Initialize the compilation backend on top of the quantization backend.

        Args:
            hn: forwarded to the parent initializer.
            hw_arch: forwarded to the parent initializer.
            alls_ignore_invalid_cmds: forwarded to the parent initializer and kept for HEF builds.

        """
        super().__init__(hn, hw_arch, alls_ignore_invalid_cmds=alls_ignore_invalid_cmds)
        self._alls_ignore_invalid_cmds = alls_ignore_invalid_cmds
        # Build artifacts; they stay None until a HEF build or an estimator run fills them.
        self._hef_data = None
        self._integrated_graph = None
        self._auto_alls = None

    def _prepare_build(self):
        """Prepare the build directory for the backend's HW architecture."""
        BuildDirCreator(self._hw_arch).prepare_build_dir()

    @staticmethod
    def _write_jlf_to_path(path, data):
        with open(path, "wb") as f:
            f.write(data)

    def _delete_build_files(self):
        """Remove the inference-stage build artifacts of the current model, where they exist."""
        config_paths = ConfigPaths(self._hw_arch, self._model.name)
        config_paths.set_stage("inference")

        artifact_keys = (
            "network_graph",
            "mapped_graph",
            "compilation_output_proto",
            "compiler_statistics",
        )
        artifact_paths = [config_paths.get_path(key) for key in artifact_keys]
        for artifact_path in artifact_paths:
            # Only regular files are removed; missing paths are silently skipped.
            if os.path.isfile(artifact_path):
                os.remove(artifact_path)

    def hef_full_build(self, fps, mapping_timeout, params, allocator_script):
        """
        Run mapping and a full HEF build for the backend's model.

        The resulting HEF data and integrated graph are also kept on the backend.

        Args:
            fps: forwarded to the tools runner.
            mapping_timeout: forwarded to the tools runner as its timeout.
            params: model params passed to the build.
            allocator_script: allocator script passed to the build.

        Returns:
            A ``(hef_data, mapped_graph_path, auto_alls)`` tuple.

        """
        self._prepare_build()
        tools_runner = HailoToolsRunner(
            self._hw_arch,
            self._model,
            fps=fps,
            clk_freq=self._hw_arch.clk_freq,
            timeout=mapping_timeout,
        )

        config_paths = ConfigPaths(self._hw_arch, self._model.name)
        config_paths.set_stage("inference")
        network_graph_path = config_paths.get_path("network_graph")
        mapped_graph_path = config_paths.get_path("mapped_graph")
        compilation_output_path = config_paths.get_path("compilation_output_proto")
        compiler_statistics_path = config_paths.get_path("compiler_statistics")

        auto_alls, self._hef_data, self._integrated_graph = tools_runner.create_mapping_and_full_build_hef(
            network_graph_path,
            mapped_graph_path,
            compilation_output_path,
            params=params,
            allocator_script=allocator_script,
            compiler_statistics_path=compiler_statistics_path,
            nms_metadata=self._nms_metadata,
            har=self.har,
            alls_ignore_invalid_cmds=self._alls_ignore_invalid_cmds,
        )

        return self._hef_data, mapped_graph_path, auto_alls

    def _compile(self, fps, allocator_script=None, mapping_timeout=None):
        """
        Build a HEF from the translated params and remember the generated auto alls.

        Args:
            fps: forwarded to ``hef_full_build``.
            allocator_script: forwarded to ``hef_full_build``.
            mapping_timeout: forwarded to ``hef_full_build``.

        Returns:
            A ``(hef, mapped_graph_file)`` tuple.

        Raises:
            BackendRuntimeException: if the model requires quantized weights but none are available.

        """
        # model_params is a copy of the sdk_backend's translated params, since they may get defused
        # during the compilation process
        translated_params = self.get_params_translated()
        model_params = ModelParams(translated_params) if translated_params else None
        if not model_params and self.requires_quantized_weights:
            raise BackendRuntimeException(
                "Model requires quantized weights in order to run on HW, but none were given. "
                "Did you forget to quantize?",
            )

        hef, mapped_graph_file, self._auto_alls = self.hef_full_build(
            fps,
            mapping_timeout,
            model_params,
            allocator_script,
        )
        return hef, mapped_graph_file

    def compile(self, fps, allocator_script=None, mapping_timeout=None):
        """
        Compile the model to a HEF.

        Default quantization params are filled into the model first; build files are deleted
        afterwards unless the SDK paths report an internal setup.

        Args:
            fps: forwarded to ``_compile``.
            allocator_script: forwarded to ``_compile``.
            mapping_timeout: forwarded to ``_compile``.

        Returns:
            The compiled HEF.

        """
        self._model.fill_default_quantization_params(logger=self._logger)
        hef, _ = self._compile(fps, allocator_script, mapping_timeout)

        # TODO: https://hailotech.atlassian.net/browse/SDK-31038
        keep_build_files = SDKPaths().is_internal
        if not keep_build_files:
            self._delete_build_files()

        return hef

    def profile(self, should_use_logical_layers=True, debug=False, hef=None, runtime_data=None, stream_fps=None):
        """
        Profile the model and, when a HEF is given, add post-placement estimation data.

        The profiling mode name is "Parsed" by default, "Optimized" when translated params exist,
        and "Compiled" when a HEF is given.

        Args:
            should_use_logical_layers: forwarded to the estimator.
            debug: forwarded to the estimator.
            hef: compiled HEF; when None only the profiler export is returned.
            runtime_data: forwarded to the estimator.
            stream_fps: forwarded to the estimator.

        Returns:
            The profiler export mapping. With a HEF, its ``stats`` entry is updated and ``csv_data``
            and ``latency_data`` entries are added.

        """
        profiling_mode_name = "Parsed"
        if self._params_translated is not None:
            profiling_mode_name = "Optimized"
        if hef is not None:
            profiling_mode_name = "Compiled"

        export, hn_model, params = self._run_profiler(profiling_mode_name)

        if hef is not None:
            estimator = self._run_estimator(
                allocator_script=None,
                debug=debug,
                hef=hef,
                hn_model=hn_model,
                params=params,
                profiling_mode=ProfilerModes.POST_PLACEMENT,
                runtime_data=runtime_data,
                should_use_logical_layers=should_use_logical_layers,
                stream_fps=stream_fps,
                accuracy_data=export["accuracy_data"],
            )
            estimator.create_log()
            stats = estimator.get_stats()
            export["stats"]["model_details"].update(stats["model_details"])
            export["stats"]["performance_details"] = stats["performance_details"]

            # The estimator writes its CSV into the current working directory; it is read back here.
            estimator.create_csv("estimation.csv")
            # newline="" lets the csv module handle line endings itself, which is required for
            # quoted fields containing embedded newlines to be parsed correctly.
            with open("estimation.csv", newline="") as csv_file:
                export["csv_data"] = list(csv.reader(csv_file))
            export["latency_data"] = estimator.get_latency_data()
        return export

    def _run_estimator(
        self,
        allocator_script,
        debug,
        hef,
        hn_model,
        params,
        profiling_mode,
        runtime_data,
        should_use_logical_layers,
        stream_fps,
        accuracy_data,
    ):
        """
        Create an estimator for the given profiling mode and clean up the build files.

        In post-placement mode the estimator is built from the existing allocation in ``hef``.
        In any other mode the tools runner produces it, and the generated auto alls are stored
        on the backend.

        Args:
            allocator_script: allocator script, used only outside post-placement mode.
            debug: used only in post-placement mode.
            hef: compiled HEF, used only in post-placement mode.
            hn_model: the model to estimate.
            params: model params forwarded to the estimator.
            profiling_mode: a ``ProfilerModes`` member selecting the estimation path.
            runtime_data: used only in post-placement mode.
            should_use_logical_layers: forwarded to the estimator.
            stream_fps: forwarded to the estimator.
            accuracy_data: forwarded to the estimator.

        Returns:
            The created estimator.

        """
        if profiling_mode is not ProfilerModes.POST_PLACEMENT:
            tools_runner = HailoToolsRunner(hw_arch=self._hw_arch, model=hn_model, clk_freq=self._hw_arch.clk_freq)
            estimator, self._auto_alls = tools_runner.run_estimator(
                "network.pb",
                "mapped_graph.pb",
                profiling_mode,
                hn_model,
                should_use_logical_layers=should_use_logical_layers,
                allocator_script=allocator_script,
                translated_params=self._params_translated,
                script_parser=self._script_parser,
                flavor_config=self.flavor_config,
                mo_flavor=self.mo_flavor,
                params=params,
                har=self.har,
                stream_fps=stream_fps,
                accuracy_data=accuracy_data,
            )
        else:
            estimator = self._estimate_from_existing_allocation(
                debug=debug,
                flavor_config=self.flavor_config,
                hef=hef,
                mo_flavor=self.mo_flavor,
                model=hn_model,
                params=params,
                profiling_mode=profiling_mode,
                runtime_data=runtime_data,
                should_use_logical_layers=should_use_logical_layers,
                stream_fps=stream_fps,
                accuracy_data=accuracy_data,
            )

        # Build files are removed on both paths.
        self._delete_build_files()
        return estimator

    def _run_profiler(self, profiling_mode_name):
        """
        Run the profiler on a copy of the backend's model.

        Args:
            profiling_mode_name: profiling mode name passed to the profiler.

        Returns:
            An ``(export, hn_model, params)`` tuple: the profiler export, the model copy that was
            profiled, and the FP-optimized params.

        """
        params = self._params_fp_optimized
        no_translated_params = self._params_translated is None

        # Create a copy to avoid defaults in the runner's model
        runner_model = self.model
        hn_model = HailoNN.from_parsed_hn(runner_model.to_hn(runner_model.name, json_dump=False))
        # The warning is disabled when no translated params exist yet.
        hn_model.fill_default_quantization_params(disable_warning=no_translated_params)

        optimization_commands = self._script_parser.export_model_optimization_commands()
        profiler = Profiler(
            profiling_mode=profiling_mode_name,
            hw_arch=self.hw_arch,
            hn=hn_model,
            params=params,
            translated_params=self._params_translated,
            hailo_optimized_params=self._params_hailo_optimized,
            statistics_params=self._params_statistics,
            modifications_meta_data=self._modifications_meta_data,
            optimization_commands=optimization_commands,
            mo_flavor=self.mo_flavor,
            flavor_config=self.flavor_config,
        )

        return profiler.get_export(), hn_model, params

    def _optimize_model_for_profiler(self, apply_fp_optimization):
        if apply_fp_optimization:
            if self._params is None and self.requires_native_weights:
                self._logger.warning(
                    "The model is not optimized and native params are not given. "
                    "Skipping full-precision optimization.",
                )
                hn_model, params = self.model, self._params_fp_optimized
            else:
                hn_model, params = self.optimize_full_precision(update_model_and_params=False)
        else:
            hn_model, params = self.model, self._params_fp_optimized

        return hn_model, params

    def _estimate_from_existing_allocation(
        self,
        debug,
        flavor_config,
        hef,
        mo_flavor,
        model,
        params,
        profiling_mode,
        runtime_data,
        should_use_logical_layers,
        stream_fps,
        accuracy_data,
    ):
        """
        Build an estimator from the allocation stored in an existing HEF.

        Args:
            debug: forwarded to the estimator.
            flavor_config: forwarded to the estimator.
            hef: compiled HEF whose proto supplies the mapping.
            mo_flavor: forwarded to the estimator.
            model: the model to estimate.
            params: forwarded to the estimator.
            profiling_mode: forwarded to the estimator.
            runtime_data: forwarded to the estimator.
            should_use_logical_layers: forwarded to the estimator.
            stream_fps: forwarded to the estimator.
            accuracy_data: forwarded to the estimator.

        Returns:
            The constructed ``Estimator``.

        """
        hw_arch = self._hw_arch
        hef_proto = HefWrapper(hef).hef_proto
        mapping = hef_proto.mapping

        # NOTE(review): ``model`` fills two positional slots of the Estimator call; what each slot
        # means is not visible from here - confirm against the Estimator signature.
        return Estimator(
            hw_arch,
            mapping,
            model,
            hw_arch.clk_freq,
            profiling_mode,
            model,
            should_use_logical_layers=should_use_logical_layers,
            translated_params=self._params_translated,
            debug=debug,
            runtime_data=runtime_data,
            hef_proto=hef_proto,
            script_parser=self._script_parser,
            flavor_config=flavor_config,
            mo_flavor=mo_flavor,
            params=params,
            stream_fps=stream_fps,
            accuracy_data=accuracy_data,
        )

    def get_mapping(self):
        """Return the integrated graph kept from the last HEF build (None before any build)."""
        integrated_graph = self._integrated_graph
        return integrated_graph

    def get_auto_alls(self):
        """Return the auto alls kept from the last compilation or estimator run (None before any)."""
        auto_alls = self._auto_alls
        return auto_alls

    @staticmethod
    def _optimization_flow_runner(
        optimization_flow: OptimizationFlow,
        checkpoint_info: CheckpointInfo,
    ) -> CheckpointInfo:
        """
        Run the optimization flow according to the checkpoint info and report the new checkpoint.

        Three cases are handled:
            * a memento exists: resume from it and run until the requested stop;
            * no memento and no requested stop: run the whole flow;
            * no memento but a requested stop: run from the start until that stop.

        Args:
            optimization_flow: the flow to run.
            checkpoint_info: carries the requested stop (``run_until``) and an optional memento.

        Returns:
            A new ``CheckpointInfo``. Quantization is marked done only when no stop was requested.
            Except for a full run from scratch, it also carries the flow's call history as the memento.

        """
        memento = checkpoint_info.flow_memento
        run_until = checkpoint_info.run_until

        if memento is not None:
            # Resume a previously interrupted flow.
            optimization_flow.run(run_until=run_until.value, memento=memento)
            return CheckpointInfo(
                quantization_done=run_until == SupportedStops.NONE,
                flow_memento=optimization_flow.call_history,
            )

        if run_until == SupportedStops.NONE:
            # Nothing to resume and nowhere to stop: run the flow end to end.
            optimization_flow.run()
            return CheckpointInfo(quantization_done=True)

        # Fresh run that stops early; keep the call history so it can be resumed later.
        optimization_flow.run(run_until=run_until.value)
        return CheckpointInfo(flow_memento=optimization_flow.call_history, quantization_done=False)


class SDKBackend(SdkBackendCompilation):
    """SDK backend class; it adds nothing on top of ``SdkBackendCompilation``."""
