import copy
import fnmatch
import json
from collections import OrderedDict
from functools import wraps
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf
import yaml
from pydantic.v1 import BaseModel

from hailo_model_optimization.acceleras.encoding.encoding_flow import EncodingFlowGraph
from hailo_model_optimization.acceleras.encoding.encoding_layer import HailoModelEncoding
from hailo_model_optimization.acceleras.encoding.encoding_sub_ops import EncodingSubOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv_add import BaseHailoConvAdd
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_crosscorrelation_dw import HailoCrossCorrelationDW
from hailo_model_optimization.acceleras.hailo_layers.hailo_dense import HailoDense
from hailo_model_optimization.acceleras.hailo_layers.hailo_io import (
    HailoCacheInputLayer,
    HailoInputLayer,
    HailoOutputLayer,
)
from hailo_model_optimization.acceleras.hailo_layers.hailo_matmul import HailoMatmul
from hailo_model_optimization.acceleras.hailo_layers.hailo_non_nn_core_output_layer import HailoNonNNCoreOutputLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_postprocess import HailoPostprocess, PostProcessConfig
from hailo_model_optimization.acceleras.hailo_layers.op_factories import (
    gen_acceleras_layers_from_hn,
    load_precision_config,
)
from hailo_model_optimization.acceleras.model.hailo_model.layer_equiv_set import LayersEquivSet
from hailo_model_optimization.acceleras.model.hailo_model.model_flow import ModelFlow
from hailo_model_optimization.acceleras.model.preprocess.preprocess import add_preprocess
from hailo_model_optimization.acceleras.model.utils.set_float64 import SetFloat64
from hailo_model_optimization.acceleras.model.utils.set_output_split_precision_zp import SetOutputSplitPrecisionZP
from hailo_model_optimization.acceleras.model.utils.set_signed_output import SetSignedOutput
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import (
    ModelOptimizationConfig,
    update_nested,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    NpzExportMode,
    OpStates,
    OptimizationTarget,
    PostprocessTarget,
    PrecisionMode,
    QuantizationAlgorithms,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasInitializationError,
    AccelerasNumerizationError,
    AccelerasValueError,
    InconsistentEncodingError,
)
from hailo_model_optimization.acceleras.utils.distributed_utils import (
    DistContextInfo,
    manage_layers_devices,
    tf_device_wrapper,
)
from hailo_model_optimization.acceleras.utils.flow_state_utils import ModelState
from hailo_model_optimization.acceleras.utils.hn_npz_utils import NpzWrap, QNpzWrap
from hailo_model_optimization.acceleras.utils.logger import default_logger
from hailo_model_optimization.acceleras.utils.params_loader import load_params


# TODO: untested
def update_graph(foo):
    """Decorate a model-mutating method so the model is rebuilt on its next use.

    After the wrapped method returns, this:
      * drops the cached ``_functional_model`` and the Keras ``predict_function``,
        ``train_function`` and ``test_function``;
      * marks every atomic op of every layer, and every layer itself, as unbuilt;
      * marks ``self.model_encoding`` as unbuilt, but only when ``self._infer_encoding`` is set;
      * marks the model itself as unbuilt.

    The wrapped method's return value is passed through unchanged.
    """

    @wraps(foo)
    def wrapper(self, *args, **kwargs):
        result = foo(self, *args, **kwargs)
        self._functional_model = None
        # TODO: do we want to build all layers or only modified ones
        for acceleras_layer in self.layers.values():
            for atomic_op in acceleras_layer.atomic_ops:
                atomic_op.built = False
            acceleras_layer.built = False
        if self._infer_encoding:
            self.model_encoding.built = False
        self.built = False
        for cached_fn_attr in ("predict_function", "train_function", "test_function"):
            setattr(self, cached_fn_attr, None)
        return result

    return wrapper


# Acceleras layer classes grouped as "shape dependent". The name suggests their behavior
# depends on the input shapes; how this list is consumed is not visible in this part of the
# file — TODO confirm against its users.
SHAPE_DEPENDENT_OPS = [HailoDense, HailoCrossCorrelationDW, HailoMatmul]


class CacheConfig(BaseModel):
    # Cache settings of a model. HailoModel.__init__ creates one when the hn "net_params"
    # contain "cache_size", copying cache_size / prefill_size from there and starting both
    # mappings empty.
    prefill_size: int  # taken from hn net_params["prefill_size"]
    cache_size: int  # taken from hn net_params["cache_size"]
    cache_mapping: dict  # starts empty; presumably filled elsewhere — key/value schema not visible here
    write_pointer_mapping: dict  # starts empty; semantics not visible in this chunk — TODO confirm


class HailoModel(tf.keras.Model):
    output_shapes: List[list]  # list of output shapes of the model

    # TODO: All hn and npz utils should be in a dedicated package and used here.
    def __init__(
        self,
        hn_dict=None,
        logger=None,
        postproc_cb=None,
        preproc_cb=None,
        nms_config: Optional[PostProcessConfig] = None,
        optimization_target: Union[OptimizationTarget, str] = None,
        lora_adapter_name: str = None,
        **kwargs,
    ):
        """
        Create a HailoModel, optionally populated from an hn.

        Args:
            hn_dict: hn passed to ``import_hn``, either as a dict or as a path to an hn JSON file.
                When falsy, the model is created without layers.
            logger: logger to use; defaults to ``default_logger()``.
            postproc_cb: post-processing callback stored on the model.
            preproc_cb: pre-processing callback stored on the model (see ``preproc_cb``).
            nms_config: config assigned to the model's single HailoPostprocess layer, if there is one.
            optimization_target: required; normalized with ``OptimizationTarget(...)``.
            lora_adapter_name: when set, only hn layers prefixed with ``"<name>/"`` are imported.
            **kwargs: forwarded to ``tf.keras.Model.__init__``.

        Raises:
            AccelerasValueError: if ``optimization_target`` is None.
            AccelerasImplementationError: if ``nms_config`` is given and the model has more than one
                HailoPostprocess layer.
        """
        super().__init__(**kwargs)
        self.supported_states = {OpStates.FP}
        self._logger = logger or default_logger()
        self._preproc_cb = preproc_cb
        self._postproc_cb = postproc_cb
        self.flow = None
        self.dist_info = DistContextInfo()
        self._acceleras_layers = {}
        self._functional_model = None  # TODO: SDK-23661, remove usage of functional model
        self.interlayer_tensors = None  # Created when running in static graph
        self._output_internal_layers = None
        self._equivalence = OrderedDict()  # equiv class?
        self.internal_layer_outputs = {}  # Used for saving the results of internal layers
        self.internal_layer_inputs = {}  # presumably used for saving the inputs of internal layers
        self.train_rescale_factor = False
        self.train_output_scale = True
        self.stop_gradient_layers = []
        self._hn_element = dict()
        self._infer_encoding = False
        self._debug_mode = False
        self.nms_config = nms_config
        self._cache_config = None
        if optimization_target is None:
            raise AccelerasValueError("optimization_target cannot be None")
        self.optimization_target = OptimizationTarget(optimization_target)
        self.lora_adapter_name = lora_adapter_name
        self._base_layer_mapping = None
        self.use_external_gpu_policy = False
        self._original_hn_layers = None

        if hn_dict:
            self.import_hn(hn_dict)
            self.handle_non_emulated_layers()
            # ``hn_dict`` may also be a file path (import_hn accepts one), and "net_params" may be
            # missing (import_hn tolerates that). Indexing ``hn_dict["net_params"]`` directly
            # crashed in both cases.
            if isinstance(hn_dict, dict):
                net_params = hn_dict.get("net_params", {})
            else:
                net_params = self._hn_element
            if "cache_size" in net_params:
                self._cache_config = CacheConfig(
                    cache_size=net_params["cache_size"],
                    prefill_size=net_params["prefill_size"],
                    cache_mapping={},
                    write_pointer_mapping={},
                )
        if nms_config:
            nms_layers = [
                lname
                for lname, acceleras_layer in self._acceleras_layers.items()
                if isinstance(acceleras_layer, HailoPostprocess)
            ]
            if len(nms_layers) == 1:
                self._acceleras_layers[nms_layers[0]].config = self.nms_config
            elif len(nms_layers) > 1:
                raise AccelerasImplementationError(
                    "Currently network with multiple NMS post-process layers is not supported yet",
                )
            else:
                # Currently happens with nms_layer in nn_core
                pass
        self._cache = {}

    def _unlock_model(self):
        """Allow changes to the model by clearing the ``locked`` flag of ``self._tracker``."""
        tracker = self._tracker
        tracker.locked = False

    def _lock_model(self):
        """Prevent changes to the model by setting the ``locked`` flag of ``self._tracker``."""
        tracker = self._tracker
        tracker.locked = True

    @property
    def model_name(self):
        """Name of the model; a read/write alias for ``self.name``."""
        return self.name

    @model_name.setter
    def model_name(self, new_name):
        self.name = new_name

    @property
    def full_name(self):
        """Full name of the model; for a HailoModel this is simply ``self.name``."""
        name = self.name
        return name

    def get_sub_model(self, sub_flow: ModelFlow, deepcopy=False) -> "HailoModel":
        """
        Build a HailoModel from a sub-graph of this model's flow.

        The sub-model shares the layer objects of this model; nothing is copied. ``sub_flow``
        must contain dummy input and output nodes (other than the model's own input/output
        layers). Each dummy node is replaced by a newly created boundary layer:
          * dummy inputs become ``HailoInputLayer``s shaped like the consuming layer's input;
          * dummy outputs become ``HailoOutputLayer``s, or ``HailoNonNNCoreOutputLayer``s when the
            producing layer is a non-NN-core layer.
        For NN-core neighbours, the scale, zero point and lossy element of the connecting tensor
        are copied onto the boundary layer, and its IO encoding is enforced.

        Args:
            sub_flow: the sub-graph to extract, including the dummy boundary nodes.
            deepcopy: must be False; deep-copying the layers is not supported.

        Returns:
            A new HailoModel over the sub-graph. Its ``output_shapes`` (without the batch dim)
            follow the order of ``sub_flow.output_nodes``.

        Raises:
            NotImplementedError: if ``deepcopy`` is True.
        """
        if deepcopy:
            raise NotImplementedError("Deep copy is not supported")
        input_shapes = [(None,) + shape for shape in self.get_input_shapes()]
        self.compute_output_shape(input_shapes)
        sub_model = HailoModel(optimization_target=self.optimization_target, logger=self._logger)

        # Walk the nodes in flow order (deduplicated) instead of through sets. Set iteration order
        # of strings changes between interpreter runs, which made the order of
        # ``sub_model.output_shapes`` and of the sub-model's layers arbitrary.
        inp_nodes = list(dict.fromkeys(sub_flow.input_nodes))
        out_nodes = list(dict.fromkeys(sub_flow.output_nodes))
        boundary_nodes = set(inp_nodes) | set(out_nodes)
        sub_model_layers = {lname: self.layers[lname] for lname in sub_flow.nodes if lname not in boundary_nodes}

        for inp_layer in inp_nodes:
            real_inputs = sub_flow.successors_sorted(inp_layer)
            real_input_lname = real_inputs[0]
            inp_ind = sub_flow.get_edge_input_index(inp_layer, real_input_lname)
            real_input_layer = self.layers[real_input_lname]
            inp_shape = real_input_layer.input_shapes[inp_ind][1:]
            shapes = [[-1, *inp_shape]]
            sub_model_layers[inp_layer] = HailoInputLayer.from_hn(
                inp_layer,
                {
                    "type": "input_layer",
                    "input_shapes": shapes,
                    "output_shapes": shapes * len(real_inputs),
                    "transposed": False,
                },
                self._logger,
            )
            if isinstance(real_input_layer, BaseHailoNonNNCoreLayer):
                # Encoding is not copied from non-NN-core consumers.
                continue
            sub_model_layers[inp_layer].set_output_scale(real_input_layer.input_scales[inp_ind], 0)
            sub_model_layers[inp_layer].set_output_zero_point(real_input_layer.input_zero_points[inp_ind], 0)
            sub_model_layers[inp_layer].atomic_op.set_output_lossy_element(
                copy.copy(real_input_layer.get_input_lossy_elements()[inp_ind]),
                0,
            )
            sub_model_layers[inp_layer].enforce_io_encoding()
            sub_model_layers[inp_layer].force_scalar_encoding = False

        output_shapes = []
        for out_layer in out_nodes:
            real_output_lname = sub_flow.predecessors_sorted(out_layer)[0]
            out_ind = sub_flow.get_edge_output_index(real_output_lname, out_layer)
            real_output_layer = self.layers[real_output_lname]
            out_ind = real_output_layer.resolve_output_index(out_ind)
            out_shape = real_output_layer.output_shapes[out_ind][1:]
            output_shapes.append(out_shape)
            shapes = [[-1, *out_shape]]

            is_non_nn_core = isinstance(real_output_layer, BaseHailoNonNNCoreLayer)
            output_layer_cls = HailoNonNNCoreOutputLayer if is_non_nn_core else HailoOutputLayer
            engine = PostprocessTarget.CPU if is_non_nn_core else PostprocessTarget.NN_CORE
            sub_model_layers[out_layer] = output_layer_cls.from_hn(
                out_layer,
                {
                    "type": "output_layer",
                    "input_shapes": shapes,
                    "output_shapes": shapes,
                    "transposed": False,
                    "engine": engine.value,
                },
                self._logger,
            )
            if is_non_nn_core:
                # Encoding is not copied from non-NN-core producers.
                continue
            sub_model_layers[out_layer].set_input_scale(real_output_layer.output_scales[out_ind], 0)
            sub_model_layers[out_layer].set_input_zero_point(real_output_layer.output_zero_points[out_ind], 0)
            sub_model_layers[out_layer].atomic_op.set_input_lossy_element(
                copy.copy(real_output_layer.get_output_lossy_elements()[out_ind]),
                0,
            )
            sub_model_layers[out_layer].enforce_io_encoding()
            sub_model_layers[out_layer].force_scalar_encoding = False

        sub_model.flow = sub_flow
        sub_model.output_shapes = output_shapes
        sub_model.model_name = f"{self.full_name}_submodel"
        sub_model._acceleras_layers = sub_model_layers
        sub_model._hn_element = self._hn_element.copy()
        sub_model._hn_element["output_layers_order"] = sub_flow.output_layer_order
        for supported_state in self.supported_states:
            sub_model.add_supported_state(supported_state)
        return sub_model

    def compile(
        self,
        optimizer="rmsprop",
        loss=None,
        metrics=None,
        loss_weights=None,
        weighted_metrics=None,
        run_eagerly=None,
        steps_per_execution=1,
        save_interlayer=None,
        stop_gradient_layers=None,
        **kwargs,
    ):
        """
        Configure the model for training; the Keras arguments are forwarded to ``tf.keras.Model.compile``.

        Args:
            save_interlayer: optional iterable of layer names. When given, ``self.interlayer_tensors``
                is reset to a dict mapping each name to None; when None, it is left untouched.
            stop_gradient_layers: layer names stored in ``self.stop_gradient_layers``
                (an empty list when None).
        """
        keras_compile_args = dict(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            loss_weights=loss_weights,
            weighted_metrics=weighted_metrics,
            run_eagerly=run_eagerly,
            steps_per_execution=steps_per_execution,
        )
        super().compile(**keras_compile_args, **kwargs)
        if save_interlayer is not None:
            self.interlayer_tensors = dict.fromkeys(save_interlayer)
        if stop_gradient_layers is None:
            stop_gradient_layers = []
        self.stop_gradient_layers = stop_gradient_layers

    @property
    def layers(self) -> Dict[str, BaseHailoLayer]:
        """
        Every layer of the model, keyed by layer name.

        Returns:
            dict mapping layer name -> acceleras layer object.
        """
        acceleras_layers = self._acceleras_layers
        return acceleras_layers

    @property
    def preproc_cb(self):
        """Pre-processing callback stored on the model (read/write)."""
        callback = self._preproc_cb
        return callback

    @preproc_cb.setter
    def preproc_cb(self, callback):
        self._preproc_cb = callback

    def export_hn(self, filename=None):
        """
        Export the model (layers info and layers flow) as an hn dict.

        Output shapes are recomputed first, because layer serialization relies on them.

        Args:
            filename: optional path; when given, the hn is also written there as indented JSON.

        Returns:
            dict describing the hn.
        """
        batched_input_shapes = [(None,) + shape for shape in self.get_input_shapes()]
        self.compute_output_shape(batched_input_shapes)
        hn = self._acceleras_model_to_hn()
        if filename is None:
            return hn
        with open(filename, "w") as hn_file:
            json.dump(hn, hn_file, indent=4)
        return hn

    def _base_model_layers_to_hn(self):
        """
        Return the hn layers of the base model, i.e. the layers not owned by the current LoRA adapter.

        The layers come from the JSON snapshot stored by ``import_hn``. Layers whose name starts
        with ``"<lora_adapter_name>/"`` are dropped. Returns an empty dict when no snapshot exists.
        """
        if self._original_hn_layers is None:
            return {}
        adapter_prefix = f"{self.lora_adapter_name}/"
        snapshot = json.loads(self._original_hn_layers)
        return OrderedDict(
            (lname, OrderedDict(layer_hn)) for lname, layer_hn in snapshot.items() if not lname.startswith(adapter_prefix)
        )

    def _acceleras_model_to_hn(self):
        """
        Assemble the full hn (``name``, ``net_params``, ``layers``) for the model.

        When a LoRA adapter is active, the base-model layers come first, followed by the
        serialized layers of this model.
        """
        serialized_layers = self._get_serialized_layers()
        hn = OrderedDict([("name", self.model_name), ("net_params", self._to_hn())])
        hn_layers = OrderedDict()
        if self.lora_adapter_name is not None:
            hn_layers.update(self._base_model_layers_to_hn())
        hn_layers.update((lname, OrderedDict(layer_info)) for lname, layer_info in serialized_layers.items())
        hn["layers"] = hn_layers
        return hn

    def _to_hn(self):
        """
        Return the model-level hn ``net_params``.

        ``self._hn_element["output_layers_order"]`` is rebuilt first. Without a LoRA adapter it
        becomes exactly the flow's output order. With an adapter, the previously stored order is
        kept minus the adapter's own layers, followed by the flow's output order.
        """
        if self.lora_adapter_name is None:
            preserved_outputs = []
        else:
            adapter_prefix = f"{self.lora_adapter_name}/"
            preserved_outputs = [
                lname for lname in self._hn_element["output_layers_order"] if not lname.startswith(adapter_prefix)
            ]
        self._hn_element["output_layers_order"] = preserved_outputs + list(self.flow.output_layer_order)
        return update_nested(dict(), self._hn_element)

    def _get_serialized_layers(self) -> Dict[str, dict]:
        """
        Serialize every layer of the flow, in topological order, as an hn element.

        Relies on each layer's ``output_shapes`` being up to date; callers such as ``export_hn``
        run ``compute_output_shape`` first. Each element gets its ``input``/``output`` edge lists
        from the flow and one entry in ``output_shapes`` per successor edge. Output nodes copy
        their ``input_shapes`` into ``output_shapes``.

        Returns:
            dict mapping layer name to its hn element.
        """
        res = {}
        output_nodes = self.flow.output_nodes
        for lname in self.flow.toposort():
            layer = self.layers[lname]
            hn_element = layer.to_hn(out_degree=self.flow.out_degree[lname])
            successors = self.flow.successors_sorted(lname)
            hn_element["input"] = self.flow.predecessors_sorted(lname)
            hn_element["output"] = successors
            original_output_shapes = copy.deepcopy(hn_element["output_shapes"])

            # !Important: a layer can have multiple outputs with different shapes, so the shape is
            # resolved per successor edge.
            edge_output_shapes = []
            for suc in successors:
                edge_data = self.flow.get_edge_data(lname, suc)
                effective_index = layer.resolve_output_index(edge_data["output_index"])
                shape = layer.output_shapes[effective_index]
                original_shape = original_output_shapes[effective_index]
                if len(original_shape) != len(shape):
                    # Work around a legacy bug/feature of mismatching fc shapes: keep the shape from
                    # the layer's own hn. This replaces the entry of *this* edge. Previously it was
                    # written at position ``effective_index``, which clobbered another edge's entry
                    # (or raised IndexError) whenever the output index differed from the
                    # successor position.
                    edge_output_shapes.append(original_shape)
                else:
                    edge_output_shapes.append([-1, *shape[1:]])
            hn_element["output_shapes"] = edge_output_shapes

            if lname in output_nodes:
                hn_element["output_shapes"] = hn_element["input_shapes"]
            if "params" in hn_element:
                hn_element["params"] = OrderedDict(hn_element["params"])
            res[lname] = hn_element
        return res

    @update_graph
    def import_hn(self, hn: Union[dict, str]):
        """
        Initialize the model from an hn (layers info and layers flow).

        Stores the hn ``net_params`` in ``self._hn_element`` and builds the acceleras layers and the
        ``ModelFlow``. When ``self.lora_adapter_name`` is set, the full layer dict is kept as a JSON
        string for export and only the adapter's layers are imported. Finally, model corrections
        are applied and preprocessing is added.

        Args:
            hn: dict that describes the hn, or a path (str) to an hn JSON file.

        Raises:
            TypeError: if ``hn`` is neither a dict nor a str.
        """
        if isinstance(hn, str):
            # JSON text is UTF-8 by spec (RFC 8259); don't depend on the platform's default encoding.
            with open(hn, encoding="utf-8") as fp:
                hn_dict = json.load(fp)
        elif isinstance(hn, dict):
            hn_dict = hn
        else:
            raise TypeError(f"Unexpected type {type(hn)} for hn")
        self._logger.debug("Creating graph from hn")
        net_params = hn_dict.get("net_params", dict())
        self._hn_element = net_params

        hn_layers = hn_dict["layers"]  # TODO: use pydantic
        self._base_layer_mapping = {}
        for lname, hn_layer in hn_layers.items():
            self._base_layer_mapping[lname] = hn_layer.get("base_layer", lname)
        if self.lora_adapter_name is not None:
            # keep original_hn_layers as a string for export so that keras wouldn't wrap the dictionary.
            self._original_hn_layers = json.dumps(hn_layers)
            hn_layers = {
                lname: hn_layer
                for lname, hn_layer in hn_layers.items()
                if lname.startswith(f"{self.lora_adapter_name}/")
            }
        # Create acceleras layers (can be moved to the end)
        self._acceleras_layers = self._create_acceleras_layers_from_hn(
            hn_layers, self.optimization_target, self._logger
        )
        output_layer_order = net_params.get("output_layers_order")
        if output_layer_order is not None:
            # Keep only outputs that were actually imported (e.g. the adapter's layers under LoRA).
            # When the hn has no order at all, None is passed through instead of failing on iteration.
            output_layer_order = [lname for lname in output_layer_order if lname in self._acceleras_layers]
        self.flow = ModelFlow.from_hn_layers(hn_layers, output_layer_order)
        output_shapes = [hn_layers[node]["output_shapes"][0] for node in self.flow.output_nodes]
        self.output_shapes = output_shapes
        self.model_name = hn_dict["name"]
        self.correct_model()
        add_preprocess(self)

    def correct_model(self):
        """Apply the model-wide corrections in order: signed output, float64, output-split precision/zp."""
        for corrector in (SetSignedOutput, SetFloat64, SetOutputSplitPrecisionZP):
            corrector.correct_model(self)

    def add_supported_state(self, *states: OpStates, layers=None):
        """
        Register op states as supported by the model and by its layers.

        Args:
            *states: states to add; each one is normalized with ``OpStates(...)``.
            layers: names of the layers to update (default: every layer). Only ``BaseHailoLayer``
                instances are updated. A state is added to ``self.supported_states`` only when the
                model has at least one layer.
        """
        target_layers = self.layers.keys() if layers is None else layers
        for raw_state in states:
            op_state = OpStates(raw_state)
            for lname in target_layers:
                acceleras_layer = self.layers[lname]
                if not isinstance(acceleras_layer, BaseHailoLayer):
                    continue
                acceleras_layer.add_supported_state(op_state)
            if len(self.layers) > 0:
                self.supported_states.add(op_state)

    def _import_config_dict(self, config: ModelOptimizationConfig, force_translation):
        """
        Apply a ModelOptimizationConfig to every layer of the model.

        Every layer imports its precision config. For NN-core layers only, the precision config is
        then verified, and the translation config is imported when one exists for the layer or
        when ``force_translation`` is set.

        Args:
            config: the model optimization config.
            force_translation: import a (possibly None) translation config even for layers without one.
        """
        self._logger.debug("Loading model configuration from config_file")
        target = self.optimization_target
        for lname, acceleras_layer in self._acceleras_layers.items():
            layer_precision_cfg = config.precision_config.layers[lname]
            acceleras_layer.import_precision_config(layer_precision_cfg, target)
            if isinstance(acceleras_layer, BaseHailoNonNNCoreLayer):
                continue
            layer_translation_cfg = config.translation_config.layers.get(lname)
            acceleras_layer.verify_config(layer_precision_cfg)
            if layer_translation_cfg is None and not force_translation:
                continue
            # TODO: the import of translation config should be part of create_io_encoding / create_hw_params logic
            acceleras_layer.import_translation_config(layer_translation_cfg)

    def export_flow_state(self) -> ModelState:
        """
        Export the flow parameters of the model.

        Aggregates each layer's exported flow state under the model's full name. Flow states describe
        how lossy elements are run (e.g. fully native, lossy, numeric lossless).
        """
        model_full_name = self.full_name
        layer_states = {lname: acceleras_layer.export_flow_state() for lname, acceleras_layer in self.layers.items()}
        return ModelState(full_name=model_full_name, layers=layer_states)

    def import_flow_state(self, model_state: ModelState) -> None:
        """
        Import flow parameters previously exported with ``export_flow_state``.

        Raises:
            AccelerasInitializationError: if ``model_state`` was exported from a model with a different full name.
        """
        names_match = self.full_name == model_state.full_name
        if not names_match:
            raise AccelerasInitializationError(
                f"while importing flow states, names didn't match. current {self.full_name} and attempted import {model_state.full_name}"
            )
        for lname, acceleras_layer in self.layers.items():
            acceleras_layer.import_flow_state(model_state.layers[lname])

    def _is_input_16bits(self, precision_mode):
        """Return True for the precision modes whose input is 16 bit (a16_w16_a16, a16_w16_a8, a16_w16)."""
        sixteen_bit_input_modes = (PrecisionMode.a16_w16_a16, PrecisionMode.a16_w16_a8, PrecisionMode.a16_w16)
        return precision_mode in sixteen_bit_input_modes

    def _get_pred_16bits_mode(self, precision_mode):
        """
        Map a 16-bit-input precision mode to the mode its predecessor should use.

        a16_w16 maps to itself, while a16_w16_a16 and a16_w16_a8 both map to a16_w16_a16.
        Any other mode yields None.
        """
        if precision_mode in (PrecisionMode.a16_w16_a16, PrecisionMode.a16_w16_a8):
            return PrecisionMode.a16_w16_a16
        if precision_mode == PrecisionMode.a16_w16:
            return precision_mode
        return None

    def _handle_config_type(self, model_config):
        """
        Normalize ``model_config`` to a ModelOptimizationConfig.

        Args:
            model_config: a ModelOptimizationConfig (returned as-is) or a path to a YAML file.

        Raises:
            AccelerasValueError: for any other type.
        """
        if isinstance(model_config, ModelOptimizationConfig):
            return model_config
        if isinstance(model_config, str):
            with open(model_config) as cfg_file:
                return ModelOptimizationConfig(**yaml.safe_load(cfg_file))
        raise AccelerasValueError(
            f"Bad type for model_config: {type(model_config).__name__}, model_config must be file, or a class",
        )

    def _update_max_elwa_config_per_layer(self, mo_config: ModelOptimizationConfig):
        """
        Fill in a missing per-layer ``max_elementwise_feed_repeat`` from the global config value.

        Defaults are applied to translation-config entries of conv-add layers (BaseHailoConvAdd)
        and to entries whose name is not a layer of this model. Other model layers are left as they are.
        """
        global_max_elwa = mo_config.globals.max_elementwise_feed_repeat
        for lname, layer_config in mo_config.translation_config.layers.items():
            applies = lname not in self.layers or isinstance(self.layers[lname], BaseHailoConvAdd)
            if applies and layer_config.max_elementwise_feed_repeat is None:
                layer_config.max_elementwise_feed_repeat = global_max_elwa

    def import_config(self, model_config: ModelOptimizationConfig, force_translation=False):
        """
        Configure the model's layers (quantization / translation settings) from a config.

        Args:
            model_config: a path to a YAML file, or a
                :class:`hailo_model_optimization.acceleras.model_optimization_config.mo_config.ModelOptimizationConfig`.
            force_translation: import translation configs even for layers that have none.
        """
        normalized_config = self._handle_config_type(model_config)
        self._update_max_elwa_config_per_layer(normalized_config)
        self._import_config_dict(normalized_config, force_translation)

    def export_hw_params(self, include_shared_weights=True) -> dict:
        """
        Export the quantized (hardware) params of the model.

        Returns:
            dict: a qnpz-style dict of all the hw params.
        """
        export_mode = NpzExportMode.QNPZ
        return self._export_npz(mode=export_mode, include_shared_weights=include_shared_weights)

    def export_acceleras(self, include_shared_weights=True) -> dict:
        """
        Export all parameters of the model together with model metadata.

        Returns:
            dict: the fully serialized acceleras model params.
        """
        export_mode = NpzExportMode.ACCELERAS
        return self._export_npz(mode=export_mode, include_shared_weights=include_shared_weights)

    def export_weights(self, include_shared_weights=True) -> dict:
        """
        Export the native (non-quantized) weights of the model.

        Returns:
            dict: native weights of the model.
        """
        export_mode = NpzExportMode.WEIGHTS
        return self._export_npz(mode=export_mode, include_shared_weights=include_shared_weights)

    def _export_npz(self, mode=NpzExportMode.ACCELERAS, include_shared_weights=True):
        """
        Collect the params of every exportable layer into an npz-style dict.

        Non-NN-core layers are skipped, and so are output layers fed by a non-NN-core layer.

        Args:
            mode: NpzExportMode.QNPZ (hw params), NpzExportMode.ACCELERAS (full acceleras params plus
                model metadata) or NpzExportMode.WEIGHTS (native weights).
            include_shared_weights: forwarded to each layer's export method.

        Returns:
            The ``params`` dict of the wrapper the layers were written into.

        Raises:
            AccelerasValueError: for any other ``mode``.
        """
        batched_input_shapes = [(None,) + shape for shape in self.get_input_shapes()]
        self.compute_output_shape(batched_input_shapes)
        if not self.built:
            self.build(batched_input_shapes)

        if mode == NpzExportMode.QNPZ:
            params_wrap = QNpzWrap(dict())
            layer_export_name = "export_hw_params"
        elif mode == NpzExportMode.ACCELERAS:
            params_wrap = NpzWrap(
                {
                    "mode": np.array([mode.value]),
                    "optimization_target": np.array([self.optimization_target.value]),
                    "supported_states": np.array([state.value for state in self.supported_states]),
                    "params_kind": np.array([4]),
                }
            )
            layer_export_name = "export_acceleras"
        elif mode == NpzExportMode.WEIGHTS:
            params_wrap = NpzWrap({"params_kind": np.array([3]), "mode": np.array([mode.value])})
            layer_export_name = "export_weights"
        else:
            raise AccelerasValueError(f"Unsupported mode {mode}")

        self._logger.debug("exporting params from acceleras to qnpz dict")
        for layer_name, acceleras_layer in self._acceleras_layers.items():
            if isinstance(acceleras_layer, BaseHailoNonNNCoreLayer):
                continue
            if isinstance(acceleras_layer, HailoOutputLayer):
                producer_name = self.flow.predecessors_sorted(layer_name)[-1]
                if isinstance(self._acceleras_layers[producer_name], BaseHailoNonNNCoreLayer):
                    continue
            self._logger.debug(f"exporting params for {layer_name}")
            export_layer_params = getattr(acceleras_layer, layer_export_name)
            params_wrap.write_params(layer_name, export_layer_params(include_shared_weights=include_shared_weights))
        return params_wrap.params

    def check_encoding_consistency(self):
        """
        Verify that every edge between two NN-core layers agrees on its encoding.

        For each flow edge (u, v), u's output scale, zero point and lossy element must match v's
        input ones. Scales and zero points are compared with ``np.allclose``, lossy elements with ``==``.

        Raises:
            InconsistentEncodingError: listing every mismatching (u, v) edge.
        """
        conflicting_layers = []
        for src_name, dst_name in self.flow.edges:
            src_layer, dst_layer = self.layers[src_name], self.layers[dst_name]
            if isinstance(src_layer, BaseHailoNonNNCoreLayer) or isinstance(dst_layer, BaseHailoNonNNCoreLayer):
                continue
            in_idx = self.flow.get_edge_input_index(src_name, dst_name)
            out_idx = src_layer.resolve_output_index(self.flow.get_edge_output_index(src_name, dst_name))
            # All three comparisons are evaluated eagerly, before any of them is checked.
            checks = (
                np.allclose(src_layer.output_scales[out_idx], dst_layer.input_scales[in_idx]),
                np.allclose(src_layer.output_zero_points[out_idx], dst_layer.input_zero_points[in_idx]),
                src_layer.get_output_lossy_elements()[out_idx] == dst_layer.get_input_lossy_elements()[in_idx],
            )
            if not all(checks):
                conflicting_layers.append((src_name, dst_name))
        if conflicting_layers:
            raise InconsistentEncodingError(conflicting_layers)

    def _add_feature_splitter_params(self, acceleras_layer: BaseHailoLayer, params_wrap: QNpzWrap):
        """
        Write the params of the feature splitter a layer was split from, if any.

        Applies only when ``acceleras_layer.from_split_layer`` is truthy. The layer's qp_in, qp_out
        and limvals are written under ``acceleras_layer.split_layer_name``. If a ``qp_out`` already
        exists there and its element [1] differs, element [1] of the new ``qp_out`` is set to 0
        (presumably the zero point — confirm).
        """
        if not getattr(acceleras_layer, "from_split_layer", None):
            return
        splitter_name = acceleras_layer.split_layer_name
        splitter_params = dict()
        for layer_params in (acceleras_layer.get_qp_in(), acceleras_layer.get_qp_out(), acceleras_layer.get_limvals()):
            splitter_params.update(layer_params)
        existing_qp_out = params_wrap.get_param(splitter_name, "qp_out")
        if existing_qp_out is not None and splitter_params["qp_out"][1] != existing_qp_out[1]:
            splitter_params["qp_out"] = np.array((splitter_params["qp_out"][0], 0))
        params_wrap.write_params(splitter_name, splitter_params)

    @update_graph
    def import_weights(self, weights):
        """
        Load weights into the model's NN-core layers.

        Args:
            weights: dict (or dict-like) with the model params, or a path to an NPZ file.

        Raises:
            AccelerasInitializationError: if the model has no layers.
        """
        weights_dict = load_params(weights) if isinstance(weights, str) else weights

        self._logger.debug("Loading params from npz")
        if not self._acceleras_layers:
            raise AccelerasInitializationError("Trying to load params to network without layers")
        npz_wrap = NpzWrap(weights_dict, base_layer_mapping=self._base_layer_mapping)
        nn_core_layers = (
            (lname, acceleras_layer)
            for lname, acceleras_layer in self._acceleras_layers.items()
            if not isinstance(acceleras_layer, BaseHailoNonNNCoreLayer)
        )
        for layer_name, acceleras_layer in nn_core_layers:
            acceleras_layer.import_weights(npz_wrap.get_layer_params(layer_name))
            if acceleras_layer.built:
                self._logger.debug(f"Overriding params for layer {layer_name}")
                # TODO: reset encodings (?)
        add_preprocess(self)

    # def export_npz(self, filepath=None, quantized=False, params_kind=3):
    #     """
    #     Export the model's weights/or hw as npz file.

    #     Args:
    #         filepath: destination file path for the new npz
    #         quantized: boolean, should the file be native npz or quantized npz

    #     """
    #     if quantized:
    #         params_dict = self.export_hw_params()
    #     else:
    #         params_dict = self.export_weights(params_kind=params_kind)
    #     if filepath is not None:
    #         np.savez(filepath, **params_dict)
    #     return params_dict

    def create_io_encoding_candidates(self):
        """
        Ask every layer, in topological order, to create its IO encoding candidates.

        NOTE: this is only used in tests; consider removing it.
        """
        for lname in self.flow.toposort():
            # Layers consume their individual tensor statistics and switch to encodings (zp/scales)
            # as the single source of truth for the per-tensor representation candidates.
            self.layers[lname].create_io_encoding_candidates()

    def vectorize_scales(self):
        """
        Call ``vectorize_scales`` on every NN-core layer, in topological order.

        Used by the optimization flow and by tests; it may belong in the optimization flow itself.
        """
        for lname in self.flow.toposort():
            acceleras_layer = self.layers[lname]
            if not isinstance(acceleras_layer, BaseHailoNonNNCoreLayer):
                acceleras_layer.vectorize_scales()

    def enforce_encoding(self, training=False, **kwargs):
        """
        WIP. Enforce the encoding constraints over every edge of the model.

        Collects all (producer, consumer) edges in topological order and hands them to
        ``enforce_constraints``, which computes the dependent scales / zero points
        (formerly called "matching"). Assumes all params are already populated (by
        constants), either by import or by computation from statistics.

        NOTE: does not invoke ``layer.infer_encodings`` (TODO consider doing that!);
        that happens in ``layer.call`` in the train flow and at the end of
        create_numerization in the PTQ flow.

        TODO: check for homogeneous activation in add/cond-add (as long as we have scalar
        elwa factors); otherwise backtrack across the equiv-set and mandate a uniform
        scalar scale (however, ReLU6 might still be ignored?).

        Args:
            training: False always, except when training *the encodings themselves* (EXPERIMENTAL).
            **kwargs: forwarded unchanged to ``enforce_constraints``.
        """
        all_edges = [*self._toposorted_edged()]
        self.enforce_constraints(all_edges, training=training, **kwargs)

    def _toposorted_edged(self):
        """
        Lazily yield every ``(layer_name, successor_name)`` edge.

        Producers are visited in topological order; the successors of one producer come in
        ``successors_sorted`` order.
        """
        for producer in self.flow.toposort():
            yield from ((producer, consumer) for consumer in self.flow.successors_sorted(producer))

    def enforce_constraints(self, edges_list, training=False, create_ratio=False, **kwargs):
        """
        Propagate encodings along the given edges so that, for every edge (u, v):
            v.input_scale       == u.output_scale
            v.input_zero_points == u.output_zero_point

        Edges whose producer is a non-NN-core layer are skipped. Each NN-core producer
        enforces its own IO encoding once, the first time it shows up as a producer.
        A ``HailoOutputLayer`` consumer enforces its IO encoding after its input was set.

        Args:
            edges_list: iterable of ``(producer_name, consumer_name)`` pairs, processed in order.
            training: forwarded to every ``enforce_io_encoding`` call.
            create_ratio: when True, ``update_io_ratio`` is called on each producer before it
                enforces its IO encoding.
            **kwargs: forwarded to every ``enforce_io_encoding`` call.
        """
        handled_producers = set()

        for producer_name, consumer_name in edges_list:
            producer = self.layers[producer_name]
            consumer = self.layers[consumer_name]
            if isinstance(producer, BaseHailoNonNNCoreLayer):
                continue
            if producer_name not in handled_producers:
                if create_ratio:
                    producer.update_io_ratio()
                producer.enforce_io_encoding(training=training, **kwargs)
                handled_producers.add(producer_name)
            consumer_is_nn_core = not isinstance(consumer, BaseHailoNonNNCoreLayer)
            if consumer_is_nn_core:
                self._set_successors_inputs_encodings(producer, consumer)
            if isinstance(consumer, HailoOutputLayer):
                consumer.enforce_io_encoding(training=training, **kwargs)

    def _set_successors_inputs_encodings(self, layer, successor):
        """
        Copy the output scale / zero point of ``layer`` on its edge towards ``successor`` into
        the matching input slot of ``successor``.

        Args:
            layer: the producing layer object (its ``full_name`` identifies the edge in the flow).
            successor: the consuming layer object.
        """
        src_name, dst_name = layer.full_name, successor.full_name

        raw_out_index = self.flow.get_edge_output_index(src_name, dst_name)
        out_index = self.layers[src_name].resolve_output_index(raw_out_index)
        scale_on_edge = layer.output_scales[out_index]
        zero_point_on_edge = layer.output_zero_points[out_index]

        in_index = self.flow.get_edge_input_index(src_name, dst_name)
        successor.set_input_scale(scale_on_edge, in_index)
        successor.set_input_zero_point(zero_point_on_edge, in_index)

    @update_graph
    def import_acceleras(self, npz):
        """
        Import weights and quantization data from an acceleras npz.

        Steps:
          1. Check that the npz optimization target matches the model's.
          2. Register the supported states of every layer.
          3. Correct the model.
          4. Import each layer's params in topological order.
          5. Re-apply the preprocess.

        Legacy npz files without state info fall back to FP / CALIBRATED / QUANTIZED.

        Raises:
            AccelerasValueError: the npz optimization target differs from the model's.
        """
        npz_wrap = NpzWrap(npz, base_layer_mapping=self._base_layer_mapping)
        model_target = self.optimization_target.value
        npz_target = npz_wrap.params.get("optimization_target", model_target)
        if npz_target != model_target:
            raise AccelerasValueError(
                f"Model optimization target is {self.optimization_target.value}, "
                f"but the npz file has {npz_target}",
            )

        # Legacy / old Npz support.
        fallback_states = npz_wrap.params.get(
            "supported_states", [OpStates.FP, OpStates.CALIBRATED, OpStates.QUANTIZED]
        )
        layers_per_state = {state: set() for state in OpStates}
        for lname in self.layers.keys():
            states_of_layer = npz_wrap.params.get(f"{lname}/layer_supported_states:0", fallback_states)
            for raw_state in states_of_layer:
                layers_per_state[OpStates(raw_state)].add(lname)
        for state, state_layers in layers_per_state.items():
            self.add_supported_state(state, layers=state_layers)

        self.correct_model()
        for lname in self.flow.toposort():
            target_layer = self._acceleras_layers[lname]
            target_layer.import_acceleras(npz_wrap.get_layer_params(lname))
        add_preprocess(self)

    @update_graph
    def import_hw_params_from_qnpz(self, qnpz, layers=None):
        """
        Load numerization info (e.g. scales, zero points) from a qnpz into the layers.

        Non-NN-core layers are skipped. Per the original note this triggers a rebuild of the
        numerized layers (TODO: make sure a rebuild isn't a requirement after numerization).

        Args:
            qnpz: qnpz dict.
            layers: names of the layers to numerize; every acceleras layer when None.
        """
        qnpz_wrap = QNpzWrap(qnpz, base_layer_mapping=self._base_layer_mapping)
        selected = list(self._acceleras_layers.keys()) if layers is None else layers
        for lname in selected:
            target_layer = self._acceleras_layers[lname]
            if isinstance(target_layer, BaseHailoNonNNCoreLayer):
                continue
            target_layer.import_qnpz(qnpz_wrap.get_layer_params(lname))

    def import_qnpz(self, qnpz):
        """
        WIP, currently for debugging: attach each layer's raw qnpz params as ``layer._qnpz_all``.

        Layers are visited in topological order. The params are stored on the layers but are
        not otherwise applied here.
        """
        qnpz_wrap = QNpzWrap(qnpz, base_layer_mapping=self._base_layer_mapping)
        for lname in self.flow.toposort():
            target_layer = self._acceleras_layers[lname]
            target_layer._qnpz_all = qnpz_wrap.get_layer_params(lname)

    @property
    def fully_native(self):
        """List of ``(layer_name, layer.fully_native)`` pairs, in ``iterate_layers()`` order (NN-core layers only)."""
        return [(name, hailo_layer.fully_native) for name, hailo_layer in self.iterate_layers()]

    def set_lossless(self, *, native_act: Optional[bool] = None):
        """
        Put the model in the "lossless" state.

        For every NN-core layer, in ``self.layers`` order:
          - clear ``fully_native``;
          - call ``disable_lossy``;
          - if the layer is already built, re-enforce its internal encoding.

        Args:
            native_act: forwarded to ``disable_lossy``; per the original note, it explicitly
                sets the activations to fully native.
        """
        for layer in self.layers.values():
            if isinstance(layer, BaseHailoNonNNCoreLayer):
                continue
            layer.fully_native = False
            layer.disable_lossy(native_act=native_act)
            if layer.built:
                layer.enforce_internal_encoding()

    def enforce_internal_encoding(self):
        """Re-enforce the internal encoding of every already-built NN-core layer."""
        for layer in self.layers.values():
            if isinstance(layer, BaseHailoNonNNCoreLayer) or not layer.built:
                continue
            layer.enforce_internal_encoding()

    def set_lossy(self, *, native_act: Optional[bool] = None):
        """
        Put the model in the "lossy" state.

        For every NN-core layer, in ``self.layers`` order:
          - clear ``fully_native``;
          - call ``enable_lossy``;
          - if the layer is already built, re-enforce its internal encoding.

        Args:
            native_act: forwarded to ``enable_lossy``; per the original note, it explicitly
                sets the activations to fully native.
        """
        nn_core_layers = [lyr for lyr in self.layers.values() if not isinstance(lyr, BaseHailoNonNNCoreLayer)]
        for lyr in nn_core_layers:
            lyr.fully_native = False
            lyr.enable_lossy(native_act=native_act)
            if lyr.built:
                lyr.enforce_internal_encoding()

    def set_native(self):
        """Put the model in the native state by marking every NN-core layer ``fully_native``."""
        for layer in self.layers.values():
            if isinstance(layer, BaseHailoNonNNCoreLayer):
                continue
            layer.fully_native = True

    def set_quantized(self, quantized_layers=None, native_layers=None, ignored_layers=None):
        """
        Toggle lossy / lossless mode per layer, so quantized layers lose accuracy over the quantized values.

        Per the original note this triggers a rebuild of the numerized layers
        (TODO: make sure a rebuild isn't a requirement after lossy).

        Args:
            quantized_layers: layers to set as quantized (lossy). Defaults to all layers.
            native_layers: layers to set as native (lossless). Defaults to the complement of
                quantized_layers and ignored_layers.
            ignored_layers: layers to leave untouched. Defaults to an empty list.

        Raises:
            AccelerasInitializationError: a layer to quantize has not been numerized.
            KeyError: a layer of the flow is in none of the resolved lists.
        """
        to_quantize, to_keep_native, to_ignore = self._get_default_and_verify_layers_to_quant(
            quantized_layers, native_layers, ignored_layers
        )
        for lname in self.flow.toposort():
            target_layer = self._acceleras_layers[lname]
            if lname in to_quantize:
                if not target_layer.is_numerized:
                    raise AccelerasInitializationError(f"Can't quantize {lname} without numerization")
                # TODO instead of defaults use quant config from "full HN" and then the new json..
                target_layer.enable_lossy()
                self._logger.debug(f"Setting lossy for layer {lname}")
                continue
            if lname in to_keep_native:
                target_layer.disable_lossy()
                self._logger.debug(f"Setting lossless for layer {lname}")
                continue
            if lname in to_ignore:
                self._logger.debug(f"Ignoring layer {lname}")
                continue
            raise KeyError(f"Layer {lname} missing from all lists in 'set_quantized'")

    def build_with_params(self, npz, qnpz=None, numerized_layers="all", quantized_layers=None):
        """
        Load weights, numerization and lossy toggles in one call.

        Steps:
          - weights are loaded with ``import_weights``;
          - numerization with ``import_hw_params_from_qnpz`` (only when ``qnpz`` is given);
          - lossy toggles with ``set_quantized`` (only when ``quantized_layers`` is given).

        The string ``"all"`` selects every layer for both layer arguments.

        Args:
            npz: npz dict with the weights.
            qnpz: qnpz dict with the numerization, or None to skip numerization.
            numerized_layers: layers to numerize, or "all".
            quantized_layers: layers to quantize, "all", or None to skip.
        """
        self.import_weights(npz)

        if qnpz is not None:
            layers_to_numerize = None if numerized_layers == "all" else numerized_layers
            self.import_hw_params_from_qnpz(qnpz, layers_to_numerize)

        if quantized_layers is not None:
            layers_to_quantize = None if quantized_layers == "all" else quantized_layers
            self.set_quantized(layers_to_quantize)

    def _get_default_and_verify_layers_to_quant(
        self,
        quantized_layers=None,
        native_layers=None,
        ignored_layers=None,
    ) -> Tuple[set, set, set]:
        """
        Resolve the defaults for ``set_quantized`` and validate the resulting layer partition.

        Defaults:
            ignored_layers   - None means an empty set.
            quantized_layers - None means every layer not ignored and not native.
            native_layers    - None means every layer not ignored and not quantized,
                               except that when quantized_layers is also None it becomes
                               an empty set (everything not ignored is quantized).

        Args:
            quantized_layers: layers that will be set as lossy.
            native_layers: layers that will be set as lossless.
            ignored_layers: layers to ignore (and leave untouched).

        Returns:
            Tuple of three sets: (quantized, native, ignored).

        Raises:
            ValueError: if any of these hold:
                - some model layer is left unclassified;
                - an unknown layer name is given;
                - a layer is assigned to more than one category.
        """
        known_layers = set(self._acceleras_layers.keys())
        ignored = set() if ignored_layers is None else set(ignored_layers)

        if quantized_layers is None:
            native = set() if native_layers is None else set(native_layers)
            quantized = known_layers - ignored - native
        else:
            quantized = set(quantized_layers)
            native = known_layers - ignored - quantized if native_layers is None else set(native_layers)

        classified = quantized | ignored | native
        unclassified = known_layers - classified
        if unclassified:
            raise ValueError(f"Missing layers when called set_quantized {unclassified}")
        unknown = classified - known_layers
        if unknown:
            raise ValueError(f"Invalid layers when called set_quantized {unknown}")
        if quantized & ignored:
            raise ValueError(f"Layers configured both as quantized and ignore {quantized & ignored}")
        if quantized & native:
            raise ValueError(f"Layers configured both as quantized and native {quantized & native}")
        if native & ignored:
            raise ValueError(f"Layers configured both as ignore and native {ignored & native}")

        return quantized, native, ignored

    @classmethod
    def _create_acceleras_layers_from_hn(cls, hn_layers, optimization_target, logger=None) -> Dict[str, BaseHailoLayer]:
        """
        Create the acceleras layers described by the HN layers dict.

        Each HN layer may expand into several acceleras layers. Their precision config is
        loaded and, for NN-core layers, ``output_shape_hn`` is set from the HN layer's
        ``output_shapes``.

        Args:
            hn_layers: the ``hn['layers']`` field of the HN dict.
            optimization_target: forwarded to the layer factory and to the precision-config loader.
            logger: optional logger forwarded to the layer factory.

        Returns:
            dict mapping layer name to acceleras layer object ``{layer_name: layer_object}``.
        """
        created = {}
        for hn_name, hn_data in hn_layers.items():
            generated = gen_acceleras_layers_from_hn(hn_name, hn_data, optimization_target, logger)
            generated = load_precision_config(hn_name, generated, hn_data, optimization_target)
            created.update(generated)
            # Relies on `generated` being ordered like hn's output_shapes
            # (e.g. the feature-split-via-slices case, see split_creator.py).
            for hn_out_shape, new_layer in zip(hn_data["output_shapes"], generated.values()):
                if not isinstance(new_layer, BaseHailoNonNNCoreLayer):
                    new_layer.output_shape_hn = hn_out_shape
        return created

    def enable_internal_encoding(self):
        """Enable the internal encoding of every layer (non-NN-core included), in topological order."""
        every_layer = (hailo_layer for _, hailo_layer in self.iterate_layers(False))
        for hailo_layer in every_layer:
            hailo_layer.enable_internal_encoding()

    def export_disable_internal_encoding(self):
        """
        Run the disable-internal-encoding flow with ``export_model_state=True``.

        Returns:
            The merged dict of per-layer changes that disabling the internal encoding
            requires (see ``_run_disable_internal_encoding``).
        """
        exported_changes = self._run_disable_internal_encoding(export_model_state=True)
        return exported_changes

    def disable_internal_encoding(self, force_endocing_layers=None):
        """
        Disable the internal encoding of all layers.

        Layers keep encoding / decoding where their input or output must stay non-native
        (see ``_run_disable_internal_encoding``).

        Args:
            force_endocing_layers: optional iterable of layer names whose outputs must stay
                decoded. The misspelled name is kept for backward compatibility. Defaults to none.
        """
        # Build a fresh list on every call; a `[]` default would be one list shared by all calls.
        forced = [] if force_endocing_layers is None else list(force_endocing_layers)
        self._run_disable_internal_encoding(force_endocing_layers=forced, export_model_state=False)

    def _run_disable_internal_encoding(
        self, force_endocing_layers: Optional[list] = None, export_model_state: bool = False
    ):
        """
        Disable the internal encoding of every layer, keeping encode/decode where it is required.

        A layer keeps decoding its outputs if any of these hold:
            - the caller forces it;
            - it is non-NN-core;
            - it has a native output;
            - it is a model output node.

        A layer keeps encoding (and quantizing) its inputs if any of these hold:
            - it is non-NN-core;
            - it has a native input;
            - it is a model input node.

        Both requirements are then propagated until nothing changes: every successor of an
        output-decoding layer must encode its inputs, and every predecessor of an
        input-encoding layer must decode its outputs.

        Args:
            force_endocing_layers: optional iterable of extra layer names whose outputs must stay decoded.
            export_model_state: forwarded to each layer's ``disable_internal_encoding``; when True
                the per-layer dicts they return are merged and returned.

        Returns:
            The merged per-layer dict when ``export_model_state`` is True, otherwise None.
        """
        # Materialize once: the order is scanned several times below.
        toposorted = list(self.flow.toposort())
        native_layers = {lname for lname in toposorted if isinstance(self.layers[lname], BaseHailoNonNNCoreLayer)}
        native_output = {lname for lname in toposorted if self.layers[lname].is_native_output}
        native_input = {lname for lname in toposorted if self.layers[lname].is_native_input}

        forced = () if force_endocing_layers is None else force_endocing_layers
        force_out_decoding_layers = set(forced) | native_layers | native_output | set(self.flow.output_nodes)
        force_in_encoding_layers = native_layers | native_input | set(self.flow.input_nodes)

        # Worklist propagation to the fixpoint: each newly added layer is expanded exactly once.
        pending_out = list(force_out_decoding_layers)
        pending_in = list(force_in_encoding_layers)
        while pending_out or pending_in:
            while pending_out:
                lname = pending_out.pop()
                for suc in self.flow.successors_sorted(lname):
                    if suc not in force_in_encoding_layers:
                        force_in_encoding_layers.add(suc)
                        pending_in.append(suc)
            while pending_in:
                lname = pending_in.pop()
                for pred in self.flow.predecessors_sorted(lname):
                    if pred not in force_out_decoding_layers:
                        force_out_decoding_layers.add(pred)
                        pending_out.append(pred)

        model_state_dict = {}
        for lname, layer in self.iterate_layers(False):
            must_encode_inputs = lname in force_in_encoding_layers
            layer_dict = layer.disable_internal_encoding(
                encode_inputs=must_encode_inputs,
                decode_outputs=lname in force_out_decoding_layers,
                quant_inputs=must_encode_inputs,
                export_model_state=export_model_state,
            )
            if export_model_state and layer_dict is not None:
                model_state_dict.update(layer_dict)
        return model_state_dict if export_model_state else None

    def _non_nn_core_layer_disable_encoder_decoder(self, layer: BaseHailoNonNNCoreLayer):
        """Turn off input decoding, then output encoding, on the non-NN-core ``layer``."""
        layer.disable_inputs_decoding()  # inputs first
        layer.disable_outputs_encoding()  # then outputs

    def _non_nn_core_layer_enable_encoder_decoder(self, layer: "BaseHailoNonNNCoreLayer"):
        """
        Enable input decoding / output encoding on a non-NN-core layer that borders NN-core layers.

        Input side: if any predecessor is a ``BaseHailoLayer``, the layer decodes its inputs.
        Each input slot fed by such a producer uses that producer's output scale / zero point
        on the connecting edge. Other slots keep scale 1 / zero point 0.

        Output side: if any successor is a ``BaseHailoLayer``, the layer encodes its outputs.
        Each output slot feeding such a consumer uses that consumer's *input* scale / zero
        point on the connecting edge. Other slots keep scale 1 / zero point 0.

        Args:
            layer: the non-NN-core layer to configure.
        """
        name = layer.full_name

        if any(isinstance(self.layers[pred], BaseHailoLayer) for pred in self.flow.predecessors_sorted(name)):
            input_scales = [1] * layer.num_inputs
            input_zero_points = [0] * layer.num_inputs
            for u, v in self.flow.in_edges(name):
                out_ind = self.flow.get_edge_output_index(u, v)
                inp_ind = self.flow.get_edge_input_index(u, v)
                producer = self.layers[u]
                if isinstance(producer, BaseHailoLayer):
                    input_scales[inp_ind] = producer.output_scales[out_ind]
                    input_zero_points[inp_ind] = producer.output_zero_points[out_ind]
            layer.enable_inputs_decoding(input_scales, input_zero_points)

        if any(isinstance(self.layers[succ], BaseHailoLayer) for succ in self.flow.successors_sorted(name)):
            output_scales = [1] * layer.num_outputs
            output_zero_points = [0] * layer.num_outputs
            for u, v in self.flow.out_edges(name):
                out_ind = self.flow.get_edge_output_index(u, v)
                inp_ind = self.flow.get_edge_input_index(u, v)
                consumer = self.layers[v]
                if isinstance(consumer, BaseHailoLayer):
                    # Fixed: previously wrote into input_scales (possibly undefined) with the
                    # consumer's output params, so enable_output_encoding always got 1/0 defaults.
                    output_scales[out_ind] = consumer.input_scales[inp_ind]
                    output_zero_points[out_ind] = consumer.input_zero_points[inp_ind]
            layer.enable_output_encoding(output_scales, output_zero_points)

    def iterate_layers(self, skip_non_nn_core=True):
        """
        Iterate over the model layers in topological order.

        Args:
            skip_non_nn_core: when True (the default), non-NN-core layers are not yielded.

        Yields:
            ``(lname, hailo_layer)`` pairs.
        """
        for lname in self.flow.toposort():
            hailo_layer = self.layers[lname]
            if not skip_non_nn_core or not isinstance(hailo_layer, BaseHailoNonNNCoreLayer):
                yield lname, hailo_layer

    def compute_output_shape(self, input_shape):
        """Return the model output shape(s) for ``input_shape``, skipping per-layer input verification."""
        shapes = self.compute_and_verify_output_shape(input_shape, verify_layer_inputs_shape=False)
        return shapes

    def compute_and_verify_output_shape(self, input_shape, verify_layer_inputs_shape=True):
        """
        Propagate shapes through the flow in topological order and return the model output shape(s).

        Args:
            input_shape: a single input shape, or a list of shapes ordered like
                ``self.flow.input_nodes``. A non-list value is wrapped in a list only when the
                model has exactly one input node.
            verify_layer_inputs_shape: when True, each layer's ``verify_layer_inputs_shape`` is
                called on the shape(s) reaching it before its output shape is computed.

        Returns:
            The shape of the single model output, or a list of shapes ordered like
            ``self.flow.output_nodes``. Each output's shape is taken from the output node's only
            predecessor. Non-NN-core layers contribute ``tf.TensorShape(None)``.

        Raises:
            RuntimeError: if an output node does not have exactly one predecessor.
        """
        input_nodes = self.flow.input_nodes
        if len(input_nodes) == 1:
            if not isinstance(input_shape, list):
                input_shape = [input_shape]
        # Seed the table with the model input shapes, keyed by input node (zip truncates on a count mismatch).
        # Each entry is overwritten below with the layer's computed output shape when the node is visited.
        all_layers_output_shapes = {node: data for node, data in zip(input_nodes, input_shape)}
        # calculate the shapes of all the ops
        for lname in self.flow.toposort():
            preds = self.flow.predecessors_sorted(lname)
            layer = self.layers[lname]
            if len(preds) == 0:
                is_input = self.flow.nodes[lname].get("is_input", False)
                if is_input:
                    layer_input_shapes = all_layers_output_shapes[lname]
                else:  # const layer
                    # NOTE(review): const layers reuse the first model input shape (presumably only
                    # its batch dimension matters) — confirm.
                    layer_input_shapes = input_shape[0]
            else:
                layer_input_shapes = []
                for pred in preds:
                    if self.layers[pred].num_outputs == 1:
                        layer_input_shapes.append(all_layers_output_shapes[pred])
                    else:
                        # Multi-output producer: pick the shape of the output feeding this edge, using the
                        # edge's "output_index" attribute (no resolve_output_index here, unlike _get_layer_inputs).
                        edge = self.flow.edges[(pred, lname)]
                        shape = all_layers_output_shapes[pred][edge["output_index"]]
                        layer_input_shapes.append(shape)
                # A single predecessor is passed as a bare shape, not a one-element list.
                if len(preds) == 1:
                    layer_input_shapes = layer_input_shapes[0]
            if verify_layer_inputs_shape:
                if layer.num_inputs <= 1:
                    layer.verify_layer_inputs_shape([layer_input_shapes])
                else:
                    layer.verify_layer_inputs_shape(layer_input_shapes)
            if isinstance(layer, BaseHailoLayer):
                layer_output_shape = layer.compute_output_shape(layer_input_shapes)
            elif isinstance(layer, BaseHailoNonNNCoreLayer):  # no valid output shape calculation method
                layer_output_shape = tf.TensorShape(None)
            else:  # non base layer, no compute_output_shape implementation (removed from keras3.x)
                # Infer the shape by calling the layer on symbolic keras inputs.
                # NOTE(review): this iterates layer_input_shapes as a list of shapes; with a single
                # predecessor it holds one bare shape here — confirm only multi-input layers reach this path.
                layer_inputs = [tf.keras.Input(batch_size=shape[0], shape=shape[1:]) for shape in layer_input_shapes]
                layer_output_shape = layer(layer_inputs).shape
            all_layers_output_shapes[lname] = layer_output_shape
        # set output shapes
        results = []
        output_nodes = self.flow.output_nodes
        for output_node in output_nodes:
            # Report the shape produced by the output node's single predecessor.
            preds = list(self.flow.predecessors(output_node))
            if len(preds) != 1:
                raise RuntimeError(
                    f"Unexpected output predecessors count in layer {self.full_name}, predecessors: {preds}"
                )
            pred = preds[0]
            results.append(all_layers_output_shapes[pred])

        if len(results) == 1:
            return results[0]
        return results

    @manage_layers_devices
    def build(self, input_shape):
        """
        Build every layer that is not built yet, then the model encoding when ``self._infer_encoding`` is set.

        Args:
            input_shape: a single shape, a list of shapes ordered like ``self.flow.input_nodes``,
                or a dict keyed by input node name. Cache input layers are dropped when a dict is
                converted to a list; their shapes are then added by
                ``_add_cache_layers_input_shapes``.

        Side effects:
            Sets ``self._build_input_shape`` to the (possibly adjusted) input shape and
            ``self.built`` to True.
        """
        if isinstance(input_shape, dict):
            input_shape = [
                input_shape[node]
                for node in self.flow.input_nodes
                if not isinstance(self.layers[node], HailoCacheInputLayer)
            ]

        if any(isinstance(self.layers[lname], HailoCacheInputLayer) for lname in self.flow.input_nodes):
            # cache layers are not part of the input shape, batch has to be taken from the first non-cache layer
            self._add_cache_layers_input_shapes(input_shape, input_shape[0][0])

        # Input layers with a conversion_type keep only the batch dim of the given shape; the remaining
        # dims come from the layer's own input_spec.
        # NOTE(review): for list inputs this assigns into input_shape in place, which also mutates the
        # caller's list — confirm that is acceptable.
        for index, lname in enumerate(self.flow.input_nodes):
            if self.layers[lname].conversion_type is not None:
                if len(self.flow.input_nodes) == 1 and not isinstance(input_shape, list):
                    input_shape = (input_shape[0],) + self.layers[lname].input_spec.shape[1:]
                else:
                    input_shape[index] = (input_shape[index][0],) + self.layers[lname].input_spec.shape[1:]
        # The result is discarded; presumably this is called for its side effect of populating
        # each layer's input_shapes (used below) — TODO confirm.
        self.compute_output_shape(input_shape)

        for lname, layer in self.layers.items():
            if not layer.built:
                with tf_device_wrapper(layer.gpu_index):
                    if (
                        getattr(self.layers[lname], "is_const_input", False)
                        and len(self.flow.predecessors_sorted(lname)) == 0
                    ):
                        # const_input layer with no predecessors: built from the model input shape
                        if len(self.flow.input_nodes) == 1 and not isinstance(input_shape, list):
                            layer.build(input_shape)
                        else:
                            layer.build(input_shape[0])
                    elif isinstance(self.layers[lname], HailoCacheInputLayer):
                        # cache input layer: use the entry at its position among the input nodes
                        layer.build(input_shape[self.flow.input_nodes.index(lname)])
                    else:
                        layer.build(layer.input_shapes)
        if self._infer_encoding and not self.model_encoding.built:
            self.model_encoding.build(input_shape)
        self._build_input_shape = input_shape
        self.built = True

    def call(
        self,
        inputs,
        layer_postproc_cb=None,
        training=False,
        save_internal_list=None,
        ignore_preproc=False,
        skip_encoding=False,
        **kwargs,
    ):
        """
        Call the model on the given inputs, inferring every layer in topological order.

        Args:
            inputs: list or dict; a list must follow the order of ``self.flow.input_nodes``.
                Any other value is treated as the single model input.
            layer_postproc_cb: optional callback applied to each layer's output, called as
                ``cb(lname, layer, layer_inputs, layer_outputs)``; its return value replaces the output.
            training: forwarded to the encoding call and to every layer call.
            save_internal_list: layer names whose outputs are stored in ``self.internal_layer_outputs``.
                Their inputs are also stored in ``self.internal_layer_inputs``, but only in debug mode.
            ignore_preproc: when True, ``self.preproc_cb`` is not applied to the inputs.
            skip_encoding: forwarded to ``self._call_encoding``.
            **kwargs: forwarded to every layer call.

        Returns:
            The single output, or a list ordered like ``self.flow.output_nodes``; ``self._postproc_cb``
            is applied to it when set. When ``self._output_internal_layers`` is set, a tuple
            ``(outputs, [outputs of the layers in self._output_internal_layers, in that order])`` is returned instead.
        """
        encoding_tensors = self._call_encoding(inputs, skip_encoding, training=training)
        inputs = self.inputs_as_dict(inputs)
        if self.preproc_cb is not None and not ignore_preproc:
            inputs = self.preproc_cb(inputs)
        inferred_layers = set()
        inference_results = dict()
        # Local collection for _output_internal_layers; distinct from self.internal_layer_outputs below.
        internal_layer_outputs = dict()
        save_internal_list = [] if save_internal_list is None else save_internal_list
        _run_eagerly = self.run_eagerly
        # Infer layers
        for lname in self.flow.toposort():
            current_inputs = self._get_layer_inputs(lname, inputs, inference_results)
            # Encoding tensors whose key starts with this layer's name.
            # NOTE(review): a plain prefix match also picks up layers whose names extend this one
            # (e.g. "conv1" vs "conv10") — confirm this is intended.
            layer_encoding = (
                {k: v for k, v in encoding_tensors.items() if k.startswith(lname)}
                if encoding_tensors is not None
                else None
            )
            output = self._call_layer(
                lname,
                current_inputs,
                layer_postproc_cb,
                training=training,
                encoding_tensors=layer_encoding,
                cache_config=self._cache_config,
                **kwargs,
            )
            inferred_layers.add(lname)
            # In eager mode, drop intermediate results that no remaining layer needs.
            if _run_eagerly:
                self._clean_results(inference_results, inferred_layers)
            if lname in save_internal_list or self._debug_mode:
                self.internal_layer_outputs[lname] = output
            if lname in save_internal_list and self._debug_mode:
                self.internal_layer_inputs[lname] = current_inputs
            if self.interlayer_tensors is not None and lname in self.interlayer_tensors:
                self.interlayer_tensors[lname] = output
            if self._output_internal_layers is not None and lname in self._output_internal_layers:
                internal_layer_outputs[lname] = output
            # Gradients stop here for the configured layers; the stopped tensor is what successors consume.
            if lname in self.stop_gradient_layers:
                output = tf.stop_gradient(output)
            self._add_inference_results(lname, output, inference_results)
        outputs = self._outputs_dict_to_list(inference_results)
        outputs = outputs[0] if len(outputs) == 1 else outputs

        if self._postproc_cb is not None:
            outputs = self._postproc_cb(outputs)

        if self._output_internal_layers is not None:
            return outputs, [internal_layer_outputs[lname] for lname in self._output_internal_layers]
        return outputs

    def _add_inference_results(self, lname, output, inference_results):
        """
        Store the output of layer ``lname`` in ``inference_results`` as ``{output_index: tensor}``.

        A single-output layer's bare result is wrapped first, so every entry is indexed uniformly.
        """
        num_outputs = self.layers[lname].num_outputs
        per_index = [output] if num_outputs == 1 else output
        inference_results[f"{lname}"] = {position: per_index[position] for position in range(num_outputs)}

    def inputs_as_dict(self, inputs) -> dict:
        """
        Normalize the model inputs into a dict ``{input_layer_name: input_value}``.

        Args:
            inputs: a dict (returned as is), a list ordered like ``self.flow.input_nodes``,
                or any other single value (tensor, numpy array, ...) treated as a one-element list.

        Returns:
            dict of inputs keyed by input layer name.
        """
        if isinstance(inputs, dict):
            return inputs
        as_list = inputs if isinstance(inputs, list) else [inputs]
        return self._inputs_list_to_dict(as_list)

    def _inputs_list_to_dict(self, inputs: list) -> dict:
        """
        Key a list of inputs by input layer name, following the order of ``self.flow.input_nodes``.

        Extra trailing inputs are ignored. Too few inputs raise ``IndexError``.

        Args:
            inputs: inputs to convert.

        Returns:
            dict mapping each input layer name to its input value.
        """
        keyed = {}
        for position, node_name in enumerate(self.flow.input_nodes):
            keyed[node_name] = inputs[position]
        return keyed

    def _outputs_dict_to_list(self, outputs: dict) -> list:
        """
        Flatten per-layer outputs into one list, following the order of ``self.flow.output_nodes``.

        Args:
            outputs: mapping ``{layer_name: {output_index: tensor}}``.

        Returns:
            list of tensors: output nodes in order, each contributing its outputs by index 0..len-1.
        """
        flattened = []
        for node in self.flow.output_nodes:
            per_node = outputs[node]
            flattened.extend(per_node[index] for index in range(len(per_node)))
        return flattened

    def _get_layer_inputs(self, lname, model_inputs, inference_results):
        """
        Collect the inputs to feed layer ``lname``.

        Args:
            lname: name of the layer about to be inferred.
            model_inputs: the user-given inputs, keyed by input layer name.
            inference_results: results of already-inferred layers, ``{lname: {output_index: tensor}}``.

        Returns:
            A single input, or a list of inputs ordered like the layer's sorted predecessors.
        """
        preds = self.flow.predecessors_sorted(lname)
        if preds:
            gathered = []
            for pred in preds:
                raw_index = self.flow.get_edge_output_index(pred, lname)
                resolved_index = self.layers[pred].resolve_output_index(raw_index)
                gathered.append(inference_results[pred][resolved_index])
            return gathered[0] if len(gathered) == 1 else gathered

        # No predecessors: the layer is an input layer or a const_input.
        source_layer = self.layers[lname]
        if getattr(source_layer, "is_const_input", False):
            # const input only takes the batch size, so the first model input is enough
            layer_inputs = list(model_inputs.values())[0]
        elif isinstance(source_layer, HailoCacheInputLayer):
            # dummy values for the cache input layer
            layer_inputs = tf.zeros(source_layer.input_shape, dtype=tf.float32)
        else:
            layer_inputs = model_inputs[lname]
        if len(layer_inputs.shape) == 2:
            # rank-2 inputs are viewed as (batch, 1, 1, features)
            layer_inputs = tf.reshape(layer_inputs, (-1, 1, 1, layer_inputs.shape[1]))
        return layer_inputs

    def _call_layer(self, lname, inputs, layer_postproc_cb=None, training=False, encoding_tensors=None, **kwargs):
        """
        Run a single layer on its inputs, on the layer's device, and return the result.

        Args:
            lname: name of the layer to infer.
            inputs: inputs to the layer.
            layer_postproc_cb: optional hook, called as ``cb(lname, layer, inputs, outputs)``.
                Its return value replaces the layer result.
            training: forwarded to the layer call.
            encoding_tensors: forwarded to the layer call.
            **kwargs: forwarded to the layer call.

        Returns:
            The layer outputs, post-processed by ``layer_postproc_cb`` when given.
        """
        target_layer = self._acceleras_layers[lname]
        with tf_device_wrapper(target_layer.gpu_index):
            layer_result = target_layer(inputs, training=training, encoding_tensors=encoding_tensors, **kwargs)
        if layer_postproc_cb is None:
            return layer_result
        return layer_postproc_cb(lname, self._acceleras_layers[lname], inputs, layer_result)

    def _clean_results(self, inference_results: dict, inferred_layers: set):
        """
        Drop cached results that no layer still needs, to free memory during inference.

        A result is dropped once every successor of its layer has been inferred.
        Results of model output nodes are always kept.

        Args:
            inference_results: dict of per-layer results; modified in place.
            inferred_layers: names of the layers already inferred.
        """
        keep_always = set(self.flow.output_nodes)
        candidates = set(inference_results) - keep_always
        for lname in candidates:
            if set(self.flow.successors(lname)).issubset(inferred_layers):
                inference_results.pop(lname)

    def _make_graph_model(self):
        """
        Build a static-graph ("Functional" in Keras terms) copy of this model and cache it.

        The model is called on symbolic placeholder inputs (not numeric tensors), which enables APIs that
        only work on graph models, such as ``summary()`` and ``plot_model()``.
        NOTE: One might expect the resulting ``graph_model`` to be faster, but Keras already "un-eagers"
              enough that the model instance can be used on its own with all the eager-mode convenience.
              The functional model is still worth keeping in mind when debugging performance issues.

        Returns:
            The ``tf.keras.Model`` named ``"graph_model"``, also stored in ``self._functional_model``.
        """
        # TODO: SDK-23661
        placeholders = self._get_placeholder_inputs()

        # call() is invoked explicitly so the Model is not treated as a single black-box layer.
        graph_outputs = self(placeholders, ignore_preproc=True)

        graph_model = tf.keras.Model(inputs=placeholders, name="graph_model", outputs=graph_outputs)
        self._functional_model = graph_model
        return self._functional_model

    def iter_equiv_sets(self, equiv_set_algo, index_equiv_set_algo=None):
        """
        Iterate over the equiv sets of an algorithm in flow order.

        Components of ``self.flow`` are visited in turn. Inside each component, layers are visited in
        topological order, and only layers that key an equiv set produce a value.

        Args:
            equiv_set_algo: (QuantizationAlgorithms) Enum represents the algorithm
            index_equiv_set_algo: Forwarded to ``get_equiv_sets`` to select a previously built set.

        Yields:
            The equiv set keyed by each matching layer.
        """
        equiv_by_layer, _ = self.get_equiv_sets(equiv_set_algo, index_equiv_set_algo)
        for component in self.flow.get_components():
            # Layers that are not the key of an equiv set are skipped.
            yield from (equiv_by_layer[name] for name in component.toposort() if name in equiv_by_layer)

    def get_equiv_sets(self, equiv_set_algo: "QuantizationAlgorithms", index_equiv_set_algo=None):
        """
        Return the equiv sets for a specific algo, building (and caching in ``self._equivalence``) them if needed.

        Args:
            equiv_set_algo: (QuantizationAlgorithms) Enum represents the algorithm
            index_equiv_set_algo: Default None. When it is a valid (non-negative, in-range) index into the equiv
                sets already built for ``equiv_set_algo``, that cached entry is returned. Otherwise a new set is
                built with ``build_equiv_sets``.

        Returns:
            Tuple ``(equiv_sets, index)``, where ``index`` is the position of ``equiv_sets`` within
            ``self._equivalence[equiv_set_algo.value]``.
        """
        algo_key = equiv_set_algo.value
        # Check the algo key before touching its list so an unseen algo falls through to building.
        if index_equiv_set_algo is not None and algo_key in self._equivalence:
            cached_sets = self._equivalence[algo_key]
            if 0 <= index_equiv_set_algo < len(cached_sets):
                # Index with the requested position itself, not with the result of the bounds check.
                return cached_sets[index_equiv_set_algo], index_equiv_set_algo
        return self.build_equiv_sets(equiv_set_algo)

    def build_equiv_sets(self, equiv_set_algo):
        """
        Build the equiv sets for a specific algo and append them to ``self._equivalence``.

        Layers are visited in topological order. Each one is offered to
        ``LayersEquivSet.build_layer_equiv_set`` together with the source layers already claimed by earlier
        sets. Every non-None result is stored under the layer's name, and its source layers are added to the
        claimed set.

        Args:
            equiv_set_algo: (QuantizationAlgorithms) Enum represents the algorithm

        Returns:
            Tuple ``(equiv_sets, index)``: the newly built ordered mapping, and its position within
            ``self._equivalence[equiv_set_algo.value]``.
        """
        built_sets = OrderedDict()
        claimed_sources = set()
        for lname in self.flow.toposort():
            candidate = LayersEquivSet.build_layer_equiv_set(self, lname, equiv_set_algo, claimed_sources)
            if candidate is not None:
                claimed_sources |= set(candidate.source_layers)
                built_sets[lname] = candidate

        algo_key = equiv_set_algo.value
        if algo_key not in self._equivalence:
            self._equivalence[algo_key] = []
        new_index = len(self._equivalence[algo_key])
        self._equivalence[algo_key].append(built_sets)
        return self._equivalence[algo_key][new_index], new_index

    def summary(self, line_length=None, positions=None, print_fn=None):
        """Delegate to Keras ``Model.summary`` on the functional graph model, building that model on first use."""
        # TODO: SDK-23661
        graph_model = self._functional_model or self._make_graph_model()
        self._functional_model = graph_model
        # TODO: Do we want to implement summary by iterating over the layers manually?
        return self._functional_model.summary(line_length, positions, print_fn)

    def plot_model(
        self,
        to_file="model.png",
        show_shapes=False,
        show_dtype=False,
        show_layer_names=True,
        rankdir="TB",
        expand_nested=False,
        dpi=96,
    ):
        """
        Render the network with ``tf.keras.utils.plot_model``.

        The functional graph model is built on first use and cached. All arguments are forwarded unchanged.
        """
        # TODO: SDK-23661
        graph_model = self._functional_model or self._make_graph_model()
        self._functional_model = graph_model
        plot_options = {
            "to_file": to_file,
            "show_shapes": show_shapes,
            "show_dtype": show_dtype,
            "show_layer_names": show_layer_names,
            "rankdir": rankdir,
            "expand_nested": expand_nested,
            "dpi": dpi,
        }
        return tf.keras.utils.plot_model(self._functional_model, **plot_options)

    def add_layer(self, layer, edges, is_input=False):
        """
        Register ``layer`` under its full name and insert it into the flow graph.

        ``built`` is reset and the cached predict/train/test functions are cleared so they are rebuilt on the
        next use.

        Args:
            layer: Layer to add; stored under ``layer.full_name``.
            edges: Forwarded to ``self.flow.insert_node``.
            is_input: Forwarded to ``self.flow.insert_node``.
        """
        layer_name = layer.full_name
        self._unlock_model()
        self.layers[layer_name] = layer
        self._lock_model()
        self.flow.insert_node(layer_name, edges, is_input)
        self.built = False
        self.predict_function = self.train_function = self.test_function = None

    def remove_layer(self, layer, connect_succ_and_pred=True):
        """
        Remove ``layer`` from the layer registry (if present) and from the flow graph.

        ``built`` is reset and the cached predict/train/test functions are cleared.

        Args:
            layer: Layer to remove, identified by ``layer.full_name``.
            connect_succ_and_pred: Forwarded to ``self.flow.remove_layer``.
        """
        layer_name = layer.full_name
        self.layers.pop(layer_name, None)
        self.flow.remove_layer(layer_name, connect_succ_and_pred=connect_succ_and_pred)
        self.built = False
        self.predict_function = self.train_function = self.test_function = None

    def replace_layer(self, new_layer, old_layer, use_new_name=False):
        """
        Replace ``old_layer`` with ``new_layer`` in the layer registry and in the flow graph.

        Args:
            new_layer: The replacement layer.
            old_layer: The layer being replaced.
            use_new_name: If True, ``new_layer`` is registered under its own full name, the flow is updated with
                ``replace_layer_manual``, and the old name's entry is dropped. If False, ``new_layer`` is stored
                under the old layer's name and the flow is updated with ``replace_layer``.

        ``built`` is reset and the cached predict/train/test functions are cleared in both cases.
        """
        old_name = old_layer.full_name
        new_name = new_layer.full_name
        if use_new_name:
            self._unlock_model()
            self.layers[new_name] = new_layer
            self._lock_model()
            self.flow.replace_layer_manual(old_name, new_name)
            # When both names are equal, popping the old name would delete the layer just inserted.
            if old_name != new_name:
                self.layers.pop(old_name, None)
        else:
            self._unlock_model()
            self.layers[old_name] = new_layer
            self._lock_model()
            self.flow.replace_layer(old_name, new_name)
        self.built = False
        self.predict_function = None
        self.train_function = None
        self.test_function = None

    def set_debug_mode(self):
        """Set ``debug_mode`` on every atomic op, and reset ``interlayer_tensors`` to one None slot per layer name."""
        self.interlayer_tensors = dict.fromkeys(self.layers.keys())
        for acceleras_layer in self._acceleras_layers.values():
            for atomic_op in acceleras_layer.atomic_ops:
                atomic_op.debug_mode = True

    def reset_debug_mode(self):
        """Clear ``debug_mode`` on every atomic op of every acceleras layer (``interlayer_tensors`` is left as is)."""
        every_op = (op for acceleras_layer in self._acceleras_layers.values() for op in acceleras_layer.atomic_ops)
        for atomic_op in every_op:
            atomic_op.debug_mode = False

    def set_output_interal_layers(self, layers: List[str]):
        """
        Set the internal layers to output during model call.

        Note that setting this will change the model predict, train and test functions to output the internal layers.
        The current predict/train/test functions are saved so `reset_output_interal_layers` can restore them.
        Calling this again while internal layers are already being output only changes the layer list; the
        originally saved functions are kept.

        For example, to get the output of the layers `'hailo_model/conv3'` and `'hailo_model/reduce_max1'`:
        >>> model: HailoModel = ...  # model that contains the layers 'hailo_model/conv3' and 'hailo_model/reduce_max1'
        >>> input_sample = ...  # input dataset
        >>> # set the model to output the internal layers
        >>> model.set_output_interal_layers(['hailo_model/conv3', 'hailo_model/reduce_max1'])
        >>> output, internal_layers_output = model.predict(input_sample)
        >>> internal_layers_output
        [<hailo_model/conv3 output>, <hailo_model/reduce_max1 output>]
        >>> # reset the model to output only the output layers
        >>> model.reset_output_interal_layers()
        >>> output = model.predict(input_sample)
        """
        # Snapshot only when entering internal-layers mode. On a repeated call, the current functions may already
        # output internal layers, and saving them would make a later reset restore internal outputs.
        if getattr(self, "_output_internal_layers", None) is None:
            self._old_predict_function = self.predict_function
            self._old_train_function = self.train_function
            self._old_test_function = self.test_function
        self._output_internal_layers = layers
        # set the predict, train and test functions to None to force the model to rebuild the graph
        self.predict_function = None
        self.train_function = None
        self.test_function = None

    def reset_output_interal_layers(self):
        """
        Reset the model to output only the output layers.

        Restores the predict/train/test functions saved by `set_output_interal_layers` and clears the saved copies.
        See `set_output_interal_layers` for more information.
        """
        self._output_internal_layers = None
        saved_functions = (self._old_predict_function, self._old_train_function, self._old_test_function)
        self.predict_function, self.train_function, self.test_function = saved_functions
        self._old_predict_function = self._old_train_function = self._old_test_function = None

    def get_input_shapes(self):
        """Return, per flow input node (in order), its layer's ``input_spec.shape`` with the leading dim dropped."""
        return [self._acceleras_layers[node].input_spec.shape[1:] for node in self.flow.input_nodes]

    def _get_placeholder_inputs(self):
        input_shapes = self.get_input_shapes()
        symbolic_inputs = list()
        for in_shape in input_shapes:
            symbolic_inputs.append(tf.keras.Input(shape=in_shape))
        return symbolic_inputs

    def _call_encoding(self, inputs, skip=False, training=False):
        """
        Get the model's encoding variables.

        Args:
            inputs: Model's inputs.
            skip: If skip or not self._infer_encoding will skip the encoding calculation. Defaults to False.

        Returns:
            dict: A dictionary of the form "encoding_name": encoding_value.

        """
        if not self._infer_encoding or skip:
            return None
        dependant_tensors = self.model_encoding(inputs, training=training)
        return dependant_tensors

    def enable_encoding_infer(self):
        """
        Enable encoding inference during model call.

        Solves the model's encoding flow into ``self.model_encoding`` (a ``HailoModelEncoding``). Every layer has
        ``infer_in_build`` and ``enforce_internal_encoding_in_call`` turned off. ``built`` (on the model and on
        the encoding sub-model) and the cached predict/train/test functions are reset so they are rebuilt.

        The model is unlocked while it is updated and is always re-locked afterwards, even if building the
        encoding flow raises.
        """
        self._unlock_model()
        try:
            encoding_flow = self.get_encoding_flow()
            encoding_inference_flow = encoding_flow.solve()
            self.model_encoding = HailoModelEncoding(f"{self.full_name}/model_encoding", encoding_inference_flow)
            self._infer_encoding = True
            for _, layer in self.iterate_layers():
                layer.infer_in_build = False
                layer.enforce_internal_encoding_in_call = False
            self.built = False
            self.model_encoding.built = False
            self.predict_function = None
            self.train_function = None
            self.test_function = None
        finally:
            # Re-lock on both success and failure so an exception does not leave the model unlocked.
            self._lock_model()

    def disable_encoding_infer(self):
        """
        Disable encoding inference during model call.

        Evaluates ``self.model_encoding`` once on a dummy ``tf.constant(0.0)`` input. Each resulting encoding is
        converted with ``.numpy()`` when it has that method, and the values are pushed into every layer and its
        atomic ops via ``update_encoding``. Layers are switched back to ``infer_in_build`` /
        ``enforce_internal_encoding_in_call``, and the cached predict/train/test functions are cleared.
        """
        computed = self.model_encoding(tf.constant(0.0))
        encoding_numpy = {}
        for key, value in computed.items():
            encoding_numpy[key] = value.numpy() if hasattr(value, "numpy") else value
        self._infer_encoding = False
        for _, layer in self.iterate_layers():
            layer.update_encoding(encoding_numpy)
            for atomic_op in layer.atomic_ops:
                atomic_op.update_encoding(encoding_numpy)
            layer.infer_in_build = True
            layer.enforce_internal_encoding_in_call = True
        self.predict_function = self.train_function = self.test_function = None

    def get_encoding_flow(self):
        """
        Return the model-wide encoding flow graph with its constraints.

        First, every layer's own encoding flow is merged into one graph. Then, for each flow edge whose endpoints
        are both NN-core layers, the producing op's output scale / zero-point encodings are tied to the consuming
        op's input scale / zero-point encodings with identity constraints. Edges where both ops have constant
        encodings are left unconstrained.

        Returns:
            EncodingFlowGraph: the merged graph.
        """
        flow = EncodingFlowGraph()
        for _, layer in self.iterate_layers():
            flow.update(layer.get_encoding_flow())

        enc = EncodingSubOp(flow)
        for src, dst, edge_data in self.flow.edges(data=True):
            # Edges touching a non-NN-core layer get no encoding constraints.
            if any(isinstance(self.layers[node], BaseHailoNonNNCoreLayer) for node in (src, dst)):
                continue
            in_pos = edge_data["input_index"]
            producer = self.layers[src]
            out_pos = producer.resolve_output_index(edge_data["output_index"])
            op_output, op_output_index = list(producer.iterate_output_ops())[out_pos]
            src_scale = f"{op_output.full_name}/output_scale:{op_output_index}"
            src_zero_point = f"{op_output.full_name}/output_zero_point:{op_output_index}"
            op_input, op_input_index = list(self.layers[dst].iterate_input_ops())[in_pos]
            dst_scale = f"{op_input.full_name}/input_scale:{op_input_index}"
            dst_zero_point = f"{op_input.full_name}/input_zero_point:{op_input_index}"
            if op_output.encoding_const and op_input.encoding_const:
                continue
            enc.identity(src_scale, dst_scale)
            enc.identity(src_zero_point, dst_zero_point)
        return flow

    def get_layers_scope_map(self) -> Dict[str, set]:
        """
        Group the model's layer names by network scope.

        Each layer name is split on its first ``/`` into ``<scope>/<base name>``. For example, ``"net/conv1"``
        contributes ``"conv1"`` under the ``"net"`` scope.

        Returns:
            Dict mapping each scope to the set of base layer names (everything after the first ``/``) under it.

        Raises:
            ValueError: if a layer name contains no ``/`` (it cannot be unpacked into scope and base name).
        """
        scope_map = dict()
        for lname in self.layers.keys():
            net_scope, base_name = lname.split("/", 1)
            scope_map.setdefault(net_scope, set()).add(base_name)
        return scope_map

    def resolution_reduction_prepare(self):
        """Disable the model IO shape verification so the input resolution can be reduced.

        Returns:
            Tuple: ``(False, {})`` if the model contains a layer of a shape-dependent op type, in which case
                nothing is changed and a warning is logged.
            Tuple: ``(True, shapes)`` otherwise, where ``shapes`` maps each input layer name to its original
                input shape (batch dim set to -1). In this case every layer's IO shape verification is disabled.
        """
        for lname, layer in self.layers.items():
            if not any(isinstance(layer, op) for op in SHAPE_DEPENDENT_OPS):
                continue
            self._logger.warning(
                f"Can't reduce resolution when the model contains {lname} layer. "
                "Optimizing the model without reducing its input resolution.",
            )
            return False, {}

        for layer in self.layers.values():
            layer.ignore_io_shapes_verification = True

        hn_by_input = {name: self.layers[name].hn_element for name in self.flow.input_nodes}
        original_input_shapes = {}
        for name, hn in hn_by_input.items():
            original_input_shapes[name] = [[-1] + hn["input_shapes"][0][1:]]
        return True, original_input_shapes

    def handle_non_emulated_layers(self):
        """
        Adapt the graph around layers that do not participate in the emulation graph.

        For each such layer, its IO shape verification is disabled. Each of its predecessors must be a
        ``HailoInputLayer``. That input gets shape verification disabled too, and its ``input_spec`` is replaced
        with the non-emulated layer's output shape (batch dim left as None).

        Raises:
            AccelerasValueError: if a predecessor of a non-emulated layer is not an input layer.
        """
        for lname, layer in self.layers.items():
            if layer.in_emulation_graph:
                continue
            layer.ignore_io_shapes_verification = True
            for pred_name in self.flow.predecessors_sorted(lname):
                pred_layer = self.layers[pred_name]
                if not isinstance(pred_layer, HailoInputLayer):
                    raise AccelerasValueError(
                        f"layer {pred_name} is not an input layer but it's successor {lname} is not in "
                        f"the emulation graph",
                    )
                pred_layer.ignore_io_shapes_verification = True
                new_shape = [None, *layer.output_shape_hn[1:]]
                pred_layer.input_spec = tf.keras.layers.InputSpec(shape=new_shape)

    def get_layer_name_with_scope(self, lname: str):
        """
        Return ``lname`` with its network scope, adding the scope if it is missing.

        A name that contains ``/`` is treated as already scoped and must exist in ``self.layers``. A bare name
        is resolved through ``get_layers_scope_map`` and must belong to exactly one scope.

        Raises:
            ValueError: if the layer is not found, or a bare name exists in more than one scope.
        """
        if "/" in lname:
            if lname not in self.layers:
                raise ValueError(f"Layer {lname} not found in model")
            return lname

        matching_scopes = [scope for scope, names in self.get_layers_scope_map().items() if lname in names]
        if len(matching_scopes) > 1:
            raise ValueError(f"Too many scopes with layer {lname}")
        if not matching_scopes:
            raise ValueError(f"Layer {lname} not found in model")
        return f"{matching_scopes[0]}/{lname}"

    def get_layer_names_from_glob(self, glob_exp):
        """
        Expand a glob pattern into matching layer names.

        A pattern that contains ``/`` is matched as-is. A pattern without a scope is tried under every scope
        returned by ``get_layers_scope_map``. Matches are returned grouped per expanded pattern, in
        ``self.layers`` order.
        """
        if "/" in glob_exp:
            patterns = [glob_exp]
        else:
            patterns = [f"{scope}/{glob_exp}" for scope in self.get_layers_scope_map()]
        return [name for pattern in patterns for name in fnmatch.filter(self.layers.keys(), pattern)]

    @property
    def bit_exact(self):
        """
        Return True if every layer has bit-exact enabled and False if none does.
        Otherwise return the list of per-layer bit_exact states.
        """
        states = [layer.bit_exact for layer in self.layers.values()]
        if all(states):
            return True
        return states if any(states) else False

    @bit_exact.setter
    def bit_exact(self, value: bool):
        """
        Set the bit_exact state of all the layers to value
        """
        if value and not self.bit_exact_emulation_supported:
            raise AccelerasNumerizationError(f"cant support bit exact on model {self.full_name}")

        for _, layer in self.iterate_layers():
            layer.bit_exact = value if layer.bit_exact_emulation_supported else False

    @property
    def bit_exact_supported(self) -> bool:
        """Return True when every layer is bit_exact_supported (bit exact is implemented for that layer)."""
        # Builtin all() returns a real bool (np.array(...).all() yields np.bool_, or an arbitrary object
        # for object arrays), matching the annotation and is_jit_compile_supported.
        return all(layer.bit_exact_supported for _, layer in self.iterate_layers())

    @property
    def bit_exact_emulation_supported(self) -> bool:
        """
        Return True when every layer is bit_exact_emulation_supported, i.e. can run a bit exact emulation
        (bit exact and not native!!).
        """
        # Builtin all() returns a real bool (np.array(...).all() yields np.bool_, or an arbitrary object
        # for object arrays), matching the annotation and is_jit_compile_supported.
        return all(layer.bit_exact_emulation_supported for _, layer in self.iterate_layers())

    def is_jit_compile_supported(self, training=False):
        """
        Return True only if every layer reports JIT-compilation support for the given ``training`` mode.

        Stops querying layers at the first one that does not support it.
        """
        for layer in self.layers.values():
            if not layer.is_jit_compile_supported(training):
                return False
        return True

    def _add_cache_layers_input_shapes(self, input_shapes_list, batch_size):
        for i, node in enumerate(self.flow.input_nodes):
            if isinstance(self.layers[node], HailoCacheInputLayer):
                input_shapes_list.insert(i, [batch_size, *self.layers[node].input_spec.shape[1:]])
