import itertools
import os
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Iterator, List, Optional, Set, Tuple

import numpy as np
import tensorflow as tf
from tensorflow.python.eager import context

from hailo_model_optimization.acceleras.atomic_ops import base_op
from hailo_model_optimization.acceleras.atomic_ops.activation_op import ActivationOp
from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp
from hailo_model_optimization.acceleras.atomic_ops.conv_stripped_op import ConvStrippedOp
from hailo_model_optimization.acceleras.atomic_ops.element_wise_add_op import ElementwiseAddOp
from hailo_model_optimization.acceleras.encoding.encoding_flow import EncodingFlowGraph
from hailo_model_optimization.acceleras.encoding.encoding_sub_ops import EncodingSubOp
from hailo_model_optimization.acceleras.hailo_layers.layer_flow import LayerFlow
from hailo_model_optimization.acceleras.lossy_elements.quant_element import (
    AccumulatorQuantElement,
    BaseQuantElement,
    QuantElement,
)
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import update_nested
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerPrecisionConfig,
    LayerTranslationConfig,
    LayerWeightsClippingConfig,
)
from hailo_model_optimization.acceleras.statistics.statistics_base import BasicTypeTuple, TypeStats
from hailo_model_optimization.acceleras.statistics.statistics_factory import ImportedStats, Statistics
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    BiasMode,
    DataPath,
    EquivClassification,
    LayerHandlerType,
    OpStates,
    OptimizationTarget,
    PrecisionMode,
    QuantizationAlgorithms,
    StatsState,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasImportError,
    AccelerasInitializationError,
    AccelerasNumerizationError,
    AccelerasValueError,
    InconsistentEncodingError,
)
from hailo_model_optimization.acceleras.utils.export.export_utils import QnpzExporter
from hailo_model_optimization.acceleras.utils.export.layer_export_utils import generic_ops_hw_params_export
from hailo_model_optimization.acceleras.utils.flow_state_utils import LayerState
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams
from hailo_model_optimization.acceleras.utils.opt_utils import (
    get_accumulator_bits_by_precision_mode,
    get_input_bits_by_precision_mode,
    get_output_bits_by_precision_mode,
    get_quant_element_by_data_path,
)
from hailo_model_optimization.acceleras.utils.qnpz_params import DEFAULT_VAL_SCALE, DEFAULT_VAL_ZP


class BaseHailoLayer(base_op.BaseOp, ABC):
    """
    CompositeOp is the closest representation of HNLayer
    a composite op is composed of AtomicOps and shouldn't use TF / Torch directly.

    [WIP] CompositeOp can be created from hn layer and be exported back to hn
    """

    # Bias modes listed as supported (by acceleras in general, judging by the attribute name).
    supported_bias_mode_acceleras = [
        BiasMode.double_scale_initialization,
        BiasMode.single_scale_decomposition,
        BiasMode.double_scale_decomposition,
    ]
    # Precision modes listed as supported (by acceleras in general, judging by the attribute name).
    supported_precision_mode_acceleras = [
        PrecisionMode.a8_w4,
        PrecisionMode.a8_w8,
        PrecisionMode.a16_w16,
        PrecisionMode.a8_w4_a8,
        PrecisionMode.a8_w4_a16,
        PrecisionMode.a8_w8_a8,
        PrecisionMode.a8_w8_a16,
        PrecisionMode.a16_w16_a8,
        PrecisionMode.a16_w8_a8,
        PrecisionMode.a16_w8_a16,
        PrecisionMode.a16_w16_a16,
        PrecisionMode.a16_w4_a8,
        PrecisionMode.a16_w4_a16,
    ]

    # The three names below are annotations only: no value is assigned here,
    # so each concrete layer is expected to define them.

    # Precision modes supported by the emulator in the specific layer
    SUPPORTED_PRECISION_MODE: Set[PrecisionMode]
    # Bias modes supported by the emulator in the specific layer
    SUPPORTED_BIAS_MODE: Set[BiasMode]
    # Indicates whether the quantization groups feature is supported by the emulator in the specific layer
    SUPPORTED_QUANTIZATION_GROUPS: bool

    def __init__(self, name, logger=None, **kwargs):
        """
        Initialize the layer state and build its flow of atomic ops.

        Args:
            name: layer name, forwarded to the base op constructor.
            logger: optional logger, forwarded to the base op constructor.
            **kwargs: extra keyword arguments forwarded to the base op constructor.
        """
        super().__init__(name=name, logger=logger, **kwargs)
        self._has_hw_params = False
        # The flow is built by the subclass hook before the attributes below are
        # assigned, so _build_flow() cannot rely on any of them.
        self._layer_flow = self._build_flow()
        # Exported as "layer_params/negative_slopes_correction_factor" by _export_layer_params.
        self._negative_slope_exponent_fix_shift = 0
        self.infer_in_build = True
        # Selects between enforce_internal_encoding (True) and
        # fast_enforce_internal_encoding (False) inside call().
        self.enforce_internal_encoding_in_call = True
        self.cross_layer_precision_mode = None
        self._hn_output_split = 1
        self._layer_precision_config = None
        self.ignore_io_shapes_verification = False
        # When set, call() keeps the per-node results on the instance.
        self._debug_mode = False
        # Reset to None whenever start_stats_collection is called.
        self._original_output_stats = None
        self._strong_force_range = False
        # Exported as "layer_params/eq_vec_out" by _export_quant_internal.
        self.eq_vec_out = 1

        # Tells the model on which GPU the layer should run (-1 lets TF decide)
        self.gpu_index = -1

        # TODO: remove this line after all layers implement new encoding.
        # Goes through the encoding_const setter, which sets the flag on every atomic op.
        self.encoding_const = True

        # Passed to the ops when creating I/O encoding candidates
        # (see create_input_encoding_candidates / create_output_encoding_candidates).
        self.output_lossy_element_external = None
        self.input_lossy_element_external = None
        self.output_split_precision_zp = None

        # Updated by finalize_from_hn according to the "base_layer" entry of the HN element.
        self.shared_weights = False

    def finalize_from_hn(self, hn_element):
        """
        Mark whether the layer shares weights with another layer, then delegate to the base class.

        The weights are considered shared when the "base_layer" entry of the HN
        element exists and differs from this layer's full name.

        Args:
            hn_element: the HN element of the layer (must support ``.get``).

        Returns:
            Whatever the base-class ``finalize_from_hn`` returns.
        """
        base_layer_name = hn_element.get("base_layer", self.full_name)
        self.shared_weights = base_layer_name != self.full_name
        return super().finalize_from_hn(hn_element)

    @classmethod
    @abstractmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Create a layer instance from an HN slice.

        OVERRIDE in subclasses with whatever is needed to create the instance.

        Args:
            lname: name of the layer (presumably the HN layer name - confirm against callers).
            hn_element: the HN slice describing the layer.
            logger: optional logger for the new instance.
        """

    @property
    def strong_force_range(self):
        """Return the stored strong-force-range flag (False by default)."""
        return self._strong_force_range

    @strong_force_range.setter
    def strong_force_range(self, value):
        """Store the strong-force-range flag as given, without validation."""
        self._strong_force_range = value

    @abstractmethod
    def _build_flow(self) -> LayerFlow:
        """
        Build the layer's flow of atomic ops.

        OVERRIDE in subclasses. The constructor stores the returned flow as
        ``self._layer_flow``; it is called before most instance attributes exist.
        """

    def call(self, inputs, training=False, fully_native=None, encoding_tensors: dict = None, **kwargs):
        """
        Run the layer by executing its atomic ops in topological order.

        Args:
            inputs: a single input, or a sequence with one entry per input node of the flow.
            training: forwarded to every atomic op and to the encoding enforcement.
            fully_native: forwarded to every atomic op; when falsy, the internal encoding
                may be enforced before the ops run (see below).
            encoding_tensors: optional dict of externally supplied encodings. When given,
                it is applied with ``update_encoding`` and each op receives the entries
                whose key starts with the op's name.
            **kwargs: forwarded to every atomic op.

        Returns:
            The layer output(s) as produced by ``_prepare_outputs``.
        """
        external_encoding = encoding_tensors is not None

        # Internal encoding is enforced only for non-native runs, in training or eager
        # execution, and only when no external encoding was supplied.
        if not fully_native and (training or context.executing_eagerly()) and not external_encoding:
            enforce = (
                self.enforce_internal_encoding
                if self.enforce_internal_encoding_in_call
                else self.fast_enforce_internal_encoding
            )
            enforce(training=training)

        if external_encoding:
            self.update_encoding(encoding_tensors)

        flow = self._layer_flow
        tensors_by_node = self._prepare_inputs(inputs)
        for op_name in flow.toposort_ops():
            op_inputs = [tensors_by_node[pred] for pred in flow.predecessors_sorted(op_name)]
            op = flow.get_op(op_name)
            # Single-input ops get the tensor itself rather than a one-element list.
            if len(op_inputs) == 1:
                op_inputs = op_inputs[0]
            if external_encoding:
                op_encoding = {key: val for key, val in encoding_tensors.items() if key.startswith(op_name)}
            else:
                op_encoding = None
            tensors_by_node[op_name] = op(
                op_inputs,
                training=training,
                fully_native=fully_native,
                encoding_tensors=op_encoding,
                **kwargs,
            )

        if self._debug_mode:
            # Kept on the instance for inspection (attribute name kept as-is for compatibility).
            self.intermidiate_results = tensors_by_node
        return self._prepare_outputs(tensors_by_node)

    def add_supported_state(self, *states: OpStates):
        """
        Register the given states as supported by every atomic op and by the layer itself.

        Args:
            *states: the states to add to each ``supported_states`` set.
        """
        for state in states:
            # Atomic ops first, then the layer - same order for every state.
            for holder in (*self.atomic_ops, self):
                holder.supported_states.add(state)

    @property
    def transpose_width_features(self):
        """
        Whether the layer uses the 'transpose_width_features' hn config.

        Used in bias correction to select the reduction axes when computing the bias diff.
        The base implementation always returns False; override it in layers that
        support the 'transpose_width_features' hn config.
        """
        return False

    @property
    def num_inputs(self):
        """Number of layer inputs, as reported by the layer flow."""
        flow = self._layer_flow
        return flow.num_inputs

    @property
    def num_outputs(self):
        """Number of layer outputs, as reported by the layer flow."""
        flow = self._layer_flow
        return flow.num_outputs

    def enable_internal_encoding(self):
        """
        Enable internal encoding on every atomic op.

        Also makes call() use the full ``enforce_internal_encoding`` again.
        """
        for atomic_op in self.atomic_ops:
            atomic_op.enable_internal_encoding()
        self.enforce_internal_encoding_in_call = True

    def disable_internal_encoding(
        self, encode_inputs=None, decode_outputs=None, quant_inputs=None, *, export_model_state=False
    ):
        """
        Disable internal encoding on every atomic op.

        Each op is told whether to encode/quantize its inputs and decode its outputs:
        this is requested at the boundary with a fully-native neighbouring op, and at
        the layer boundary when the matching argument asks for it.

        Args:
            encode_inputs: encode the inputs of the layer's input ops.
            decode_outputs: decode the outputs of the layer's output ops.
            quant_inputs: quantize the inputs of the layer's input ops.
            export_model_state: when truthy, collect and return the state each op reports.

        Returns:
            A dict merging the per-op states when ``export_model_state`` is truthy,
            otherwise None.
        """
        input_op_names = [entry[0].full_name for entry in self.iterate_input_ops()]
        output_op_names = [entry[0].full_name for entry in self.iterate_output_ops()]
        exported_state = {}

        for atomic_op in self.atomic_ops:
            native_successor = self._suc_op_is_native(atomic_op)
            native_predecessor = self._pred_op_is_native(atomic_op)
            forced_by_successor = self._suc_op_force_encoding(atomic_op, input_op_names, encode_inputs)

            is_input_op = atomic_op.full_name in input_op_names
            is_output_op = atomic_op.full_name in output_op_names

            op_state = atomic_op.disable_internal_encoding(
                decode_outputs=native_successor or (decode_outputs and is_output_op) or forced_by_successor,
                encode_inputs=native_predecessor or (encode_inputs and is_input_op),
                quant_inputs=native_predecessor or (quant_inputs and is_input_op),
                export_model_state=export_model_state,
            )

            if export_model_state and op_state is not None:
                exported_state.update(op_state)

        # From now on call() uses fast_enforce_internal_encoding.
        self.enforce_internal_encoding_in_call = False

        return exported_state if export_model_state else None

    def _suc_op_force_encoding(self, op, input_ops, encode_inputs):
        for suc in self._layer_flow.successors_sorted(op.full_name):
            if suc in input_ops and encode_inputs:
                return True
        return False

    def _suc_op_is_native(self, op):
        for suc in self._layer_flow.successors_sorted(op.full_name):
            if suc in self._layer_flow.get_outputs():
                continue
            if self._layer_flow.get_op(suc).fully_native:
                return True
        return False

    def _pred_op_is_native(self, op):
        for pred in self._layer_flow.predecessors_sorted(op.full_name):
            if pred in self._layer_flow.get_inputs():
                continue
            if self._layer_flow.get_op(pred).fully_native:
                return True
        return False

    @property
    def is_native_input(self):
        """True when at least one of the layer's input ops is fully native."""
        native_flags = [entry[0].fully_native for entry in self.iterate_input_ops()]
        return np.any(native_flags)

    @property
    def is_native_output(self):
        """True when at least one of the layer's output ops is fully native."""
        native_flags = [entry[0].fully_native for entry in self.iterate_output_ops()]
        return np.any(native_flags)

    @property
    def fully_native(self):
        """
        Summarize the fully_native state of the atomic ops.

        Returns:
            True if every atomic op is fully native (also when there are no ops),
            False if none is, otherwise the list of per-op states.
        """
        per_op = [atomic_op.fully_native for atomic_op in self.atomic_ops]
        if all(per_op):
            return True
        if not any(per_op):
            return False
        return per_op

    @fully_native.setter
    def fully_native(self, value):
        """
        Set the fully_native state of all the atomic ops to value
        """
        for atomic_op in self.atomic_ops:
            atomic_op.fully_native = value

    @property
    def bit_exact(self):
        """
        Summarize the bit_exact state of the atomic ops.

        Returns:
            True if every atomic op is bit-exact enabled (also when there are no ops),
            False if none is, otherwise the list of per-op states.
        """
        per_op = [atomic_op.bit_exact for atomic_op in self.atomic_ops]
        if all(per_op):
            return True
        if not any(per_op):
            return False
        return per_op

    @bit_exact.setter
    def bit_exact(self, value):
        """
        Set the bit_exact state of the atomic ops.

        Ops that do not support bit-exact, or that are fully native, are always set to False.

        Raises:
            AccelerasNumerizationError: when enabling bit-exact on a layer that does not support it.
        """
        if value and not self.bit_exact_supported:
            raise AccelerasNumerizationError(f"cant support bit exact{self.full_name}")

        for atomic_op in self.atomic_ops:
            can_be_exact = atomic_op.bit_exact_supported and not atomic_op.fully_native
            atomic_op.bit_exact = value if can_be_exact else False

    @property
    def bit_exact_emulation_supported(self) -> bool:
        """True when every atomic op supports bit-exact emulation."""
        per_op = np.array([atomic_op.bit_exact_emulation_supported for atomic_op in self.atomic_ops])
        return per_op.all()

    @property
    def bit_exact_supported(self) -> bool:
        """True when every atomic op supports bit-exact."""
        per_op = np.array([atomic_op.bit_exact_supported for atomic_op in self.atomic_ops])
        return per_op.all()

    @property
    def is_precision_transparent(self) -> bool:
        """
        Whether the layer is precision transparent.

        The base implementation always returns False; presumably overridden by
        layers for which it is True (no override is visible here).
        """
        return False

    @property
    def encoding_const(self):
        """
        Summarize the encoding_const state of the atomic ops.

        Returns:
            True if every atomic op has constant encodings (also when there are no ops),
            False if none has, otherwise the list of per-op states.
        """
        per_op = [atomic_op.encoding_const for atomic_op in self.atomic_ops]
        if all(per_op):
            return True
        if not any(per_op):
            return False
        return per_op

    @encoding_const.setter
    def encoding_const(self, value):
        """
        Set the encoding_const state of all the atomic ops to value
        """
        for atomic_op in self.atomic_ops:
            atomic_op.encoding_const = value

    def _prepare_inputs(self, inputs):
        """Prepare inputs for inference, convert array to dict"""
        input_nodes = self._layer_flow.get_inputs()
        if len(input_nodes) == 1:
            input_data = {input_nodes[0]: inputs}
        else:
            if len(inputs) != len(input_nodes):
                raise ValueError(
                    f"Inputs and input nodes not the same length - inputs: {len(inputs)}, nodes: {len(input_nodes)}",
                )
            input_data = {node: data for node, data in zip(input_nodes, inputs)}
        return input_data

    def _prepare_outputs(self, intermidiate_results):
        """Prepare output after inference, convert dict to array"""
        results = []
        output_nodes = self._layer_flow.get_outputs()
        for output_node in output_nodes:
            preds = list(self._layer_flow.predecessors(output_node))
            if len(preds) != 1:
                raise RuntimeError(
                    f"Unexpected output predecessors count in layer {self.full_name}, predecessors: {preds}"
                )
            pred = preds[0]
            results.append(intermidiate_results[pred])
        if len(results) == 1:
            results = results[0]
        return results

    @property
    def atomic_ops(self):
        """The atomic ops of the layer, as returned by the layer flow."""
        flow = self._layer_flow
        return flow.get_ops()

    def is_differentiable(self):
        """Return True when every atomic op reports itself as differentiable."""
        for atomic_op in self.atomic_ops:
            if not atomic_op.is_differentiable():
                return False
        return True

    def get_algo_callback(self, algo):
        """
        Return the handler-type classifier of this layer for the given algorithm.

        Args:
            algo: one of the QuantizationAlgorithms members handled below.

        Returns:
            The bound ``get_*_handler_type`` method matching ``algo``.

        Raises:
            KeyError: when ``algo`` has no classifier.
        """
        classifiers = {
            QuantizationAlgorithms.equalization: self.get_equalization_handler_type,
            QuantizationAlgorithms.params_sorter: self.get_params_sorter_handler_type,
            QuantizationAlgorithms.dead_channels_removal: self.get_dead_channels_removal_handler_type,
            QuantizationAlgorithms.quarot: self.get_quarot_handler_type,
        }
        callback = classifiers[algo]
        return callback

    @abstractmethod
    def get_equalization_handler_type(self, predecessor_index=None) -> EquivClassification:
        """
        Classify the layer for the equalization algorithm.

        OVERRIDE in subclasses. This is the callback that get_algo_callback returns
        for QuantizationAlgorithms.equalization.

        Args:
            predecessor_index: the predecessor index in the inputs list of the current layer.

        Returns:
            The classification of the layer, as an EquivClassification.

        """

    def get_params_sorter_handler_type(self, predecessor=None):
        """
        Classify the layer for the params sorter algorithm (not implemented here).

        Raises:
            AccelerasImplementationError: always, in this base implementation.
        """
        # TODO: Change to abstract once we start implementation
        message = "acceleras does not support params sorter is not supported yet"
        raise AccelerasImplementationError(message)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """
        Classify the layer for the dead channels removal algorithm (not implemented here).

        Raises:
            AccelerasImplementationError: always, in this base implementation.
        """
        # TODO: Change to abstract once we start implementation
        message = "acceleras does not support dead_channels is not supported yet"
        raise AccelerasImplementationError(message)

    def get_quarot_handler_type(self, predecessor_index=None):
        """
        Classify the layer for the quarot algorithm.

        The base implementation reports the layer as unsupported and not a source;
        ``predecessor_index`` is not used.
        """
        classification = EquivClassification(LayerHandlerType.unsupported, is_source=False)
        return classification

    def export_hn(self):
        """
        Export the layer back to HN (not implemented here).

        Raises:
            AccelerasImplementationError: always, in this base implementation.
        """
        # TODO: Change to abstract once we start implementation
        message = "acceleras export hn is not supported yet"
        raise AccelerasImplementationError(message)

    def export_weights(self, include_shared_weights=True):
        """
        Export the layer's weights.

        Args:
            include_shared_weights: when falsy, a layer with shared weights exports nothing.

        Returns:
            An empty dict for a shared-weights layer when ``include_shared_weights``
            is falsy, otherwise the result of ``_export_weights``.
        """
        skip_export = self.shared_weights and not include_shared_weights
        return {} if skip_export else self._export_weights()

    def _export_weights(self):
        """
        OVERRIDE in subclasses with whatever is needed to
        export all pre-trained native weights. The weights should be returned as a dict.

        Raises:
            AccelerasImplementationError: always, in this base implementation.
        """
        # TODO: Change to abstract once we finish implementation
        message = f"export weights is not implemented for {self.full_name}"
        raise AccelerasImplementationError(message)

    # region Export Hw Params

    def _export_layers_qp_defaults(self) -> dict:
        """This is some legacy that we export qp in and qpout, but this information is not allways need it"""
        params = dict()
        params.update(self.get_qp_in())
        params.update(self.get_qp_out())
        params.update({key: np.array(val, dtype=np.float32) for key, val in self.get_limvals().items()})
        params["zero_point_in"] = np.int32(params["qp_in"][0])
        params["zero_point_out"] = np.int32(params["qp_out"][0])
        return params

    def _export_layer_params(self):
        output_zero_points = self._update_scalar_to_scale(self.output_zero_points, self.output_scales)
        input_zero_points = self._update_scalar_to_scale(self.input_zero_points, self.input_scales)

        output_scales = self._change_list_to_np_array(self.output_scales)
        input_scales = self._change_list_to_np_array(self.input_scales)

        output_zero_points = self._change_list_to_np_array(output_zero_points)
        input_zero_points = self._change_list_to_np_array(input_zero_points)

        layer_params = {
            "layer_params/output_scales": output_scales,
            "layer_params/input_scales": input_scales,
            "layer_params/output_zero_points": output_zero_points,
            "layer_params/input_zero_points": input_zero_points,
            "layer_params/negative_slopes_correction_factor": np.array(
                self._negative_slope_exponent_fix_shift,
                np.float32,
            ),
        }
        return layer_params

    def _export_ops_hw_params(self) -> dict:
        """
        Export the HW params of the atomic ops in the generic way.

        TODO: add a better generic export that adds a suffix to the ops in a generic way.
        """
        return generic_ops_hw_params_export(self.atomic_ops)

    def _layer_dependent_hw_params_modifications(self, params: dict) -> dict:
        """
        Layer-level hook for modifying or adding exported HW params.

        If a layer needs to modify or add parameters on the layer level, this is the
        function to override, leaving export_hw_params as a clean API function.
        The base implementation returns ``params`` unchanged.

        Args:
            params: the HW params collected so far by export_hw_params.

        Returns:
            The params to export.
        """
        return params

    def export_hw_params(self, layer_level_params=None, include_shared_weights=True):
        """
        Export the HW params of the layer.

        Lossy mode is enabled and the internal encoding is enforced before the params
        are collected from the ops, the qp defaults and the layer-level params.

        Args:
            layer_level_params: optional dict to start from. NOTE: it is updated in place.
            include_shared_weights: when falsy, shared weights are left out of the export
                of a layer that shares its weights.

        Returns:
            The result of ``QnpzExporter.export`` for the collected params.
        """
        params = {} if layer_level_params is None else layer_level_params
        self.enable_lossy()  # Dont know if this should be Done (maybe fix the test)
        self.enforce_internal_encoding()
        # Later sources override earlier ones on key clashes.
        for collect in (
            self._export_ops_hw_params,
            self._export_layers_qp_defaults,
            self._export_layer_params,
        ):
            params.update(collect())
        params = self._layer_dependent_hw_params_modifications(params)
        exporter = QnpzExporter(self.__class__.__name__, params)
        with_shared = include_shared_weights or not self.shared_weights
        return exporter.export(include_shared_weights=with_shared)

    def enable_lossy(self, *, native_act: Optional[bool] = None, **kwargs):
        """
        Enable lossy mode on every atomic op.

        Args:
            native_act: when not None, the fully_native state of every ActivationOp
                is set to this value before its lossy mode is enabled.
            **kwargs: forwarded to each op's ``enable_lossy``.
        """
        override_activation = native_act is not None
        for aop in self.atomic_ops:
            if override_activation and isinstance(aop, ActivationOp):
                aop.fully_native = native_act
            aop.enable_lossy(**kwargs)

    def disable_lossy(self, *, native_act: Optional[bool] = None, **kwargs):
        """
        Disable lossy mode on every atomic op.

        Args:
            native_act: when not None, the fully_native state of every ActivationOp
                is set to this value before its lossy mode is disabled.
            **kwargs: forwarded to each op's ``disable_lossy``.
        """
        override_activation = native_act is not None
        for aop in self.atomic_ops:
            if override_activation and isinstance(aop, ActivationOp):
                aop.fully_native = native_act
            aop.disable_lossy(**kwargs)

    def _input_stats_ops(self):
        for op, input_index in self.iterate_input_ops():
            yield op, input_index

    def _output_stats_ops(self):
        for op, output_index in self.iterate_output_ops():
            yield op, output_index

    def start_stats_collection(self, stats_cfg: tuple = BasicTypeTuple, output_hist=False, preact_hist=False):
        """
        Start collecting statistics on the layer's inputs, outputs and pre-activations.

        Args:
            stats_cfg: the statistics types to collect.
            output_hist: also collect a dynamic histogram on the layer outputs.
            preact_hist: also collect a dynamic histogram on the inputs of the ops
                returned by ``_iterate_act_ops``.
        """
        self._original_output_stats = None
        output_cfg = (*stats_cfg, TypeStats.DYNAMIC_HISTOGRAM) if output_hist else stats_cfg
        preact_cfg = (*stats_cfg, TypeStats.DYNAMIC_HISTOGRAM) if preact_hist else stats_cfg

        for op, _ in self._input_stats_ops():
            op.start_stats_collection(stats_cfg=stats_cfg, collect_inputs=True, collect_output=False)
        for op, _ in self._output_stats_ops():
            op.start_stats_collection(stats_cfg=output_cfg, collect_inputs=False, collect_output=True)
        for op in self._iterate_act_ops():
            op.start_stats_collection(stats_cfg=preact_cfg, collect_inputs=True, collect_output=False)

    def stop_stats_collection(self):
        """Stop statistics collection on every atomic op."""
        for atomic_op in self.atomic_ops:
            atomic_op.stop_stats_collection()

    def check_encoding_consistency(self):
        """
        Verify that the encodings agree on every edge between two atomic ops.

        For each edge (skipping edges that touch a placeholder node), the producer's
        output scale, zero point and lossy element must match the consumer's input ones.

        Raises:
            InconsistentEncodingError: with the list of conflicting edges, if any.
        """
        flow = self._layer_flow
        conflicting_edges = []
        for src, dst in flow.edges:
            if flow.is_placeholder(src) or flow.is_placeholder(dst):
                continue
            producer = flow.get_op(src)
            consumer = flow.get_op(dst)
            input_index = flow.get_edge_input_index(src, dst)
            output_index = flow.get_edge_output_index(src, dst)
            same_scale = np.all(producer.output_scales[output_index] == consumer.input_scales[input_index])
            same_zp = np.all(producer.output_zero_points[output_index] == consumer.input_zero_points[input_index])
            # Assume post quantization
            same_lossy = self._validate_consistency_of_lossy_elements(producer, output_index, consumer, input_index)
            if not (same_scale and same_zp and same_lossy):
                conflicting_edges.append((src, dst))
        if conflicting_edges:
            raise InconsistentEncodingError(conflicting_edges)

    def _validate_consistency_of_lossy_elements(self, u_layer, output_index: int, v_layer, input_index: int):
        return u_layer.output_lossy_elements[output_index] == v_layer.input_lossy_elements[input_index]

    def enforce_internal_encoding(self, training=False, **kwargs):
        """
        [Optional]
        Compute all "dependent" encoding params from the minimal set of "source" ones,
         under the **layer-specific HW constraints** which this method thus implements.
        The "source" set always includes the input&output scale ("inward-scales-propagation").
        These might be dependent too (e.g. input scale is previous layer's output scale),
          as enforced by the Model's method synonymous to this one, invoked prior and setting them.
        For more details see:
        https://hailotech.atlassian.net/wiki/spaces/ML/pages/943259731/Vectorized+Scales
        https://hailotechcom-my.sharepoint.com/:w:/g/personal/alexf_hailo_ai/EfzOd6bw2B9BtUmzdqC9oa4B92CKMI-HKXXFgPPWjbJzwQ?e=2cTysg

        Normally all params are numbers and the context is quantization or completion of partial import.
        However, for scales-training, some of the "independent" params are Variables,
         and this method is called in graph context to permit backprop.
        So ideally the implementation should be pure TF.

        Implementations should strive to mostly consist of invoking atomic-op infer_encodings(),
        and then syncing its output encodings into input encodings of the next atomic op.
        However, for complex layers, the need for layer-specific ordering and syncing is accepted.

        The base implementation does nothing.
        """
        # TODO consider making abstract OR use simple loop baseline:
        #     for op in self.atomic_ops:
        #         op.enforce_encoding()

    def fast_enforce_internal_encoding(self, **kwargs):
        """
        Variant of enforce_internal_encoding used by call() once internal encoding is disabled.

        The base implementation simply delegates to ``enforce_internal_encoding``
        with the same keyword arguments and returns nothing.
        """
        self.enforce_internal_encoding(**kwargs)

    def lossy_status_summary(self):
        """
        Handy utility for debug.

        Returns:
            A string listing, for every atomic op, its input, weight and output
            lossy elements.
        """
        parts = ["\n"]
        for aop in self.atomic_ops:
            parts.extend(
                (
                    "---",
                    aop.full_name,
                    " lossy elements",
                    "input: ",
                    "".join(str(el) for el in aop.input_lossy_elements),
                    "weights: ",
                    "".join(str(el) for el in aop.weight_lossy_elements),
                    "output: ",
                    "".join(str(el) for el in aop.output_lossy_elements),
                )
            )
        parts.append("\n")
        return "".join(parts)

    @property
    def input_scales(self):
        """List of the layer's input scales, read from the input ops in iteration order."""
        # TODO consider switching to abstract,
        #  in certain layer might create tricky bugs if not overriden properly
        return [op.input_scales[index] for op, index in self.iterate_input_ops()]

    @property
    def output_scales(self):
        """List of the layer's output scales, read from the output ops in iteration order."""
        # TODO consider switching to abstract,
        #  in certain layer might create tricky bugs if not overriden properly
        return [op.output_scales[index] for op, index in self.iterate_output_ops()]

    @property
    def input_zero_points(self):
        """List of the layer's input zero points, read from the input ops in iteration order."""
        # see property input_scales comment
        return [op.input_zero_points[index] for op, index in self.iterate_input_ops()]

    @property
    def output_zero_points(self):
        """List of the layer's output zero points, read from the output ops in iteration order."""
        # see property input_scales comment
        return [op.output_zero_points[index] for op, index in self.iterate_output_ops()]

    @property
    def input_scale(self):
        """
        The input scale of a single-input layer.

        Raises:
            AttributeError: when the layer does not have exactly one input.
        """
        if self.num_inputs == 1:
            return self.input_scales[0]
        raise AttributeError(
            f"The layer has multiple inputs. "
            f'Hence the notion of "input_scale" is '
            f"ill-defined for the layer - {self.full_name}",
        )

    @property
    def output_scale(self):
        """
        The output scale of a single-output layer.

        Raises:
            AttributeError: when the layer does not have exactly one output.
        """
        if self.num_outputs == 1:
            return self.output_scales[0]
        raise AttributeError(
            f"The layer has multiple outputs. "
            f'Hence the notion of "output_scale" is '
            f"ill-defined for the layer - {self.full_name}",
        )

    @property
    def input_zero_point(self):
        """
        The input zero point of a single-input layer.

        Raises:
            AttributeError: when the layer does not have exactly one input.
        """
        if self.num_inputs == 1:
            return self.input_zero_points[0]
        raise AttributeError(
            f"The layer has multiple inputs. "
            f'Hence the notion of "input_zero_point" is '
            f"ill-defined for the layer - {self.full_name}",
        )

    @property
    def output_zero_point(self):
        """
        The output zero point of a single-output layer.

        Raises:
            AttributeError: when the layer does not have exactly one output.
        """
        if self.num_outputs == 1:
            return self.output_zero_points[0]
        raise AttributeError(
            f"The layer has multiple outputs. "
            f'Hence the notion of "output_zero_point" is '
            f"ill-defined for the layer - {self.full_name}",
        )

    def _enforce_input_encoding(self):
        pass

    def set_input_scale(self, scale, index):
        """
        Set the scale of the layer input at position ``index``.

        Args:
            scale: the new scale.
            index: position of the input among the layer's input ops.
        """
        op, op_index = list(self.iterate_input_ops())[index]
        op.input_scales[op_index] = scale

    def set_output_scale(self, scale, index):
        """
        Set the scale of the layer output at position ``index``.

        Args:
            scale: the new scale.
            index: position of the output among the layer's output ops.
        """
        op, op_index = list(self.iterate_output_ops())[index]
        op.output_scales[op_index] = scale

    def _force_output_scale(self):
        """
        this function is used to force the output scale of the layer
        """
        pass

    def update_io_ratio(self):
        """Force the output scale (layer hook), then recompute the output/input scale ratio."""
        self._force_output_scale()
        self._create_out_in_scale_ratio()

    def set_input_zero_point(self, zero_point, index):
        """
        Set the zero point of the layer input at position ``index``.

        Args:
            zero_point: the new zero point.
            index: position of the input among the layer's input ops.
        """
        op, op_index = list(self.iterate_input_ops())[index]
        op.input_zero_points[op_index] = zero_point

    def set_output_zero_point(self, zero_point, index):
        """
        Set the zero point of the layer output at position ``index``.

        Args:
            zero_point: the new zero point.
            index: position of the output among the layer's output ops.
        """
        op, op_index = list(self.iterate_output_ops())[index]
        op.output_zero_points[op_index] = zero_point

    def vectorize_scales(self):
        """Vectorize the scales of every atomic op, then recompute the output/input scale ratio."""
        for atomic_op in self.atomic_ops:
            atomic_op.vectorize_scales()
        self._create_out_in_scale_ratio()

    def create_io_encoding_candidates(self, translation_config=None):
        """
        Consume the statistics and create the I/O encoding-param candidates,
        first for the layer inputs and then for the layer outputs.

        Args:
            translation_config: forwarded to both candidate-creation steps.
        """
        shared_kwargs = {"translation_config": translation_config}
        self.create_input_encoding_candidates(**shared_kwargs)
        self.create_output_encoding_candidates(**shared_kwargs)

    def create_input_encoding_candidates(self, translation_config=None):
        """
        Create the input encoding candidates on every op returned by ``_input_stats_ops``.

        Args:
            translation_config: forwarded to each op, together with the layer's
                external input lossy element.
        """
        for op, input_index in self._input_stats_ops():
            candidate_kwargs = {
                "input_index": input_index,
                "input_lossy_external": self.input_lossy_element_external,
                "translation_config": translation_config,
            }
            op.create_input_encoding_candidates(**candidate_kwargs)

    def create_output_encoding_candidates(self, forced_range=None, translation_config=None):
        """
        Create the output encoding candidates on every op returned by ``_output_stats_ops``,
        then recompute the output/input scale ratio.

        Args:
            forced_range: forwarded to each op as its second positional argument.
            translation_config: forwarded to each op, together with the layer's external
                output lossy element and output split-precision zero point.
        """
        for op, output_index in self._output_stats_ops():
            candidate_kwargs = {
                "output_lossy_external": self.output_lossy_element_external,
                "translation_config": translation_config,
                "split_precision_zp": self.output_split_precision_zp,
            }
            op.create_output_encoding_candidates(output_index, forced_range, **candidate_kwargs)
        self._create_out_in_scale_ratio()

    def enforce_io_encoding(self, training=False, **kwargs):
        """
        Method serving the model-level infer_encodings (formerly, SCALES MATCHING) functionality.
            If the output encoding (scale&zp) of a layer is constrained
            and fully determined by the input encoding (w.o. regard to internal params!),
            this method implements the relation and overwrites the output scale&zp
            of the layer, given all input scales. Override in each layer.

        Examples:
              A. Non-arithmetic layers: concat, maxpool, nn-resize - implemented, simple per-channel propagation.
              B. Conv - NOT implemented, I/O scales are related through internal params, which are inferred
                 from I/O scales in infer_encodings (with output scale as independent/unconstrained).
              C. ConvAdd - special case, see there.

        Args:
            training: a boolean to indicate whether it's used in offline or scales-training context.
            **kwargs: extra keyword arguments, unused by the base implementation.

        Returns:
            Nothing, but should set output_scale and output_zero_point if implemented

        Raises:
            AccelerasImplementationError: always, in this base implementation.

        """
        message = f"enforce_io_encoding is not implemented for {self.full_name}"
        raise AccelerasImplementationError(message)

    def _propagate_encoding_forward(self, atomic_op: BaseAtomicOp, enforce_encoding=False):
        """
        Propagate the output encoding of an atomic op to its successors.

        The output scale and zero point on each outgoing edge are copied into the
        matching input of the successor op.

        Args:
            atomic_op: the op whose output encoding is propagated.
            enforce_encoding: when truthy, ``atomic_op.enforce_encoding()`` is called first.
        """
        if enforce_encoding:
            atomic_op.enforce_encoding()
        flow = self._layer_flow
        source_name = atomic_op.full_name
        for suc in flow.successors_sorted(source_name):
            suc_op = flow.get_op(suc)
            input_index = flow.get_edge_input_index(source_name, suc)
            output_index = flow.get_edge_output_index(source_name, suc)
            suc_op.input_scales[input_index] = atomic_op.output_scales[output_index]
            suc_op.input_zero_points[input_index] = atomic_op.output_zero_points[output_index]

    def _propagate_encoding_backward(self, atomic_op: BaseAtomicOp, enforce_encoding=False):
        """
        Propagate the input encoding of an atomic op to its predecessors.

        The input scale and zero point on each incoming edge are copied into the
        matching output of the predecessor op.

        Args:
            atomic_op: the op whose input encoding is propagated.
            enforce_encoding: when truthy, ``atomic_op.enforce_encoding(forward=False)``
                is called first.
        """
        if enforce_encoding:
            atomic_op.enforce_encoding(forward=False)
        flow = self._layer_flow
        target_name = atomic_op.full_name
        for pred in flow.predecessors_sorted(target_name):
            pred_op = flow.get_op(pred)
            input_index = flow.get_edge_input_index(pred, target_name)
            output_index = flow.get_edge_output_index(pred, target_name)
            pred_op.output_scales[output_index] = atomic_op.input_scales[input_index]
            pred_op.output_zero_points[output_index] = atomic_op.input_zero_points[input_index]

    def _export_quant_internal(self, include_shared_weights=True):
        """
        Collect the quantization params of the layer into a single flat dict.

        Lossy mode is enabled and the internal encoding is enforced first, then the export of
        every atomic op is merged, followed by the layer-level params and ``eq_vec_out``.

        The keys should look as follows:
            <model_name>/<layer_name>/<op_name>/<param_name>

        Args:
            include_shared_weights: forwarded to each op's ``export_quant``; it is forced to True
                when ``self.shared_weights`` is falsy.

        Returns:
            dict with the merged export.
        """
        self.enable_lossy()
        self.enforce_internal_encoding()

        collected = {}
        for op in self.atomic_ops:
            with_shared = include_shared_weights or not self.shared_weights
            collected.update(op.export_quant(include_shared_weights=with_shared))
        collected.update(self._export_layer_params())
        collected["layer_params/eq_vec_out"] = self.eq_vec_out
        return collected

    def _validate_scales(self, scales, params, key):
        """
        Check that ``scales`` match the (NaN-padded) rows stored under ``params[key]``.

        Nothing is checked when ``key`` is absent from ``params`` (or maps to None).

        Raises:
            AccelerasNumerizationError: when a row, after dropping NaN entries, differs from the scale.
        """
        stored_rows = params.get(key)
        if stored_rows is None:
            return
        for idx, layer_scale in enumerate(scales):
            row = stored_rows[idx]
            # NaN entries are padding, drop them before comparing
            unpadded = row[~np.isnan(row)]
            if np.array_equal(unpadded, layer_scale):
                continue
            raise AccelerasNumerizationError(
                f"in layer {self.full_name} {key} are differnet from params {unpadded} {layer_scale}",
            )

    def _validate_zero_points(self, zero_points, params, key):
        """
        Check that each row stored under ``params[key]`` is a constant vector (ignoring NaN padding)
        equal to the corresponding entry of ``zero_points``.

        The check is skipped when ``self.vector_zp_in`` is set or when ``key`` is absent from ``params``.

        Raises:
            AccelerasNumerizationError: when a row is not constant, or its value differs from the zero point.
        """
        if self.vector_zp_in:
            return
        stored_rows = params.get(key)
        if stored_rows is None:
            return
        for idx, layer_zp in enumerate(zero_points):
            zero_point_vector = stored_rows[idx]
            # NaN entries are padding, drop them before comparing
            unpadded = zero_point_vector[~np.isnan(zero_point_vector)]
            if np.any(unpadded != unpadded[0]):
                raise AccelerasNumerizationError(f"in layer {self.full_name} {key} are differnet")
            if unpadded[0] != layer_zp:
                raise AccelerasNumerizationError(
                    f"in layer {self.full_name} {key} {layer_zp} are differnet from params {zero_point_vector[0]}",
                )

    def _validate_io_encoding_params(self, params):
        """
        Validate the layer's I/O scales and zero points against the ``layer_params/*`` entries of ``params``.

        Order of checks: output scales, input scales, output zero points, input zero points.
        """
        checks = (
            (self._validate_scales, "output_scales", "layer_params/output_scales"),
            (self._validate_scales, "input_scales", "layer_params/input_scales"),
            (self._validate_zero_points, "output_zero_points", "layer_params/output_zero_points"),
            (self._validate_zero_points, "input_zero_points", "layer_params/input_zero_points"),
        )
        for validate, attr_name, params_key in checks:
            validate(getattr(self, attr_name), params, params_key)

    def _export_layer_metadata(self):
        """
        Export layer-level metadata: the shape-verification flag, the sorted values of the supported
        states and, when the layer has an activation op, its ``fully_native`` flag.
        """
        # TODO because we still export QNPZ we cant add it to the op
        metadata = {"ignore_io_shapes_verification": self.ignore_io_shapes_verification}
        state_values = sorted(state.value for state in self.supported_states)
        metadata["layer_supported_states"] = np.array(state_values)

        act_op = self.activation_atomic_op
        if act_op:
            metadata["act_fully_native"] = act_op.fully_native
        return metadata

    def _import_layer_metadata(self, npz):
        """Restore the shape-verification flag and, if the layer has an activation op, its ``fully_native`` flag."""
        self.ignore_io_shapes_verification = npz["ignore_io_shapes_verification"]
        act_op = self.activation_atomic_op
        if act_op:
            act_op.fully_native = npz["act_fully_native"]

    def export_acceleras(self, include_shared_weights=True):
        """
        Export the layer as a flat dict for an npz file, following the atomic ops structure.

        Stats are included when CALIBRATED is a supported state and quantization params when
        QUANTIZED is; weights and layer metadata are always included.
        """
        # TODO SDK-48985 not all keys are consistance , there is a test on hailo layer
        # TODO that enables this check.
        exported = {}
        supported = self.supported_states
        if OpStates.CALIBRATED in supported:
            exported.update(self.export_stats())
        if OpStates.QUANTIZED in supported:
            exported.update(self.export_quantize_params(include_shared_weights=include_shared_weights))
        exported.update(self.export_weights(include_shared_weights=include_shared_weights))
        exported.update(self._export_layer_metadata())
        return exported

    def import_acceleras(self, acceleras_params):
        """
        Load the layer from an exported params dict: weights first, then stats (if CALIBRATED is
        supported), quantization params (if QUANTIZED is supported) and finally the layer metadata.
        """
        self.import_weights(acceleras_params)

        supported = self.supported_states
        if OpStates.CALIBRATED in supported:
            self.import_stats(acceleras_params)
        if OpStates.QUANTIZED in supported:
            self._import_quant_internal(acceleras_params)

            # TODO SDK-48985 we might want to enforce encodings after import to set the state
            # self.enforce_internal_encoding()
            # self.enable_lossy()

        self._import_layer_metadata(acceleras_params)

    def export_quantize_params(self, include_shared_weights=True):
        """Public entry point for exporting the quantization params; delegates to ``_export_quant_internal``."""
        return self._export_quant_internal(include_shared_weights=include_shared_weights)

    def export_flow_state(self) -> LayerState:
        """
        Export the flow state of the layer.

        Returns:
            LayerState holding the layer name, the flow state of every atomic op (keyed by the
            op's full name) and the ``enforce_internal_encoding_in_call`` flag.
        """
        layer_name = self.full_name
        ops_state = {}
        for op in self.atomic_ops:
            op_name = op.full_name
            ops_state[op_name] = op.export_flow_state()
        return LayerState(
            full_name=layer_name,
            atomic_ops=ops_state,
            enforce_internal_encoding_in_call=self.enforce_internal_encoding_in_call,
        )

    def import_flow_state(self, layer_state: LayerState) -> None:
        """
        Restore the flow state of the layer from ``layer_state``.

        The ``enforce_internal_encoding_in_call`` flag is restored and every atomic op imports
        its own state, looked up by the op's full name.

        Raises:
            AccelerasInitializationError: when ``layer_state`` belongs to a differently named layer.
        """
        names_match = self.full_name == layer_state.full_name
        if not names_match:
            raise AccelerasInitializationError(
                f"while importing flow states, names didn't match. current {self.full_name} and attempted import {layer_state.full_name}"
            )
        self.enforce_internal_encoding_in_call = layer_state.enforce_internal_encoding_in_call
        for op in self.atomic_ops:
            op_state = layer_state.atomic_ops[op.full_name]
            op.import_flow_state(op_state)

    def get_limvals(self):
        """
        Compute the representable value range of the first input (if any) and the first output.

        The range is derived from the lossy element's min/max values, shifted by the zero point
        and multiplied by the scale.

        Returns:
            dict with "limvals_out" and, when the layer has inputs, "limvals_in"; each is a (min, max) tuple.
        """

        def _real_range(lossy_element, zero_point, scale):
            upper = tf.cast(lossy_element.max_value - zero_point, scale.dtype) * scale
            lower = tf.cast(lossy_element.min_value - zero_point, scale.dtype) * scale
            return (np.min(lower), np.max(upper))

        limvals = {}
        if len(self.input_scales) >= 1:
            in_scale = self.input_scales[0]
            in_zp = self.input_zero_points[0]
            in_element = self.get_input_lossy_elements()[0]
            limvals["limvals_in"] = _real_range(in_element, in_zp, in_scale)

        out_scale = self.output_scales[0]
        out_zp = self.output_zero_points[0]
        out_element = self.get_output_lossy_elements()[0]
        limvals["limvals_out"] = _real_range(out_element, out_zp, out_scale)
        return limvals

    def get_qp_in(self):
        """
        Build the scalar (zero_point, scale) pair of the first input.

        When the scale / zero-point vector is not constant, the corresponding entry falls back to
        DEFAULT_VAL_SCALE / DEFAULT_VAL_ZP.

        Returns:
            ``{"qp_in": np.array((zp, scale))}``, or an empty dict when the layer has no input scales.
        """
        if len(self.input_scales) == 0:
            return dict()
        scale_vec = self.input_scales[0]
        zp_vec = self.input_zero_points[0]

        # take the first element as the representative value, unless the value is already a scalar
        scale_repr = scale_vec if scale_vec.shape == () else scale_vec[0]
        zp_repr = zp_vec if np.array(zp_vec).shape == () else zp_vec[0]

        if not np.all(scale_vec == scale_repr):
            scale_repr = DEFAULT_VAL_SCALE
        if not np.all(zp_vec == zp_repr):
            zp_repr = DEFAULT_VAL_ZP
        return {"qp_in": np.array((zp_repr, scale_repr))}

    def get_qp_out(self):
        """
        Build the scalar (zero_point, scale) pair of the first output.

        When the scale / zero-point vector is not constant, the corresponding entry falls back to
        DEFAULT_VAL_SCALE / DEFAULT_VAL_ZP.

        Returns:
            ``{"qp_out": np.array((zp, scale))}``
        """
        scale_vec = self.output_scales[0]
        zp_vec = self.output_zero_points[0]

        # take the first element as the representative value, unless the value is already a scalar
        scale_repr = scale_vec if scale_vec.shape == () else scale_vec[0]
        zp_repr = zp_vec if np.array(zp_vec).shape == () else zp_vec[0]

        if not np.all(scale_vec == scale_repr):
            scale_repr = DEFAULT_VAL_SCALE
        if not np.all(zp_vec == zp_repr):
            zp_repr = DEFAULT_VAL_ZP
        return {"qp_out": np.array((zp_repr, scale_repr))}

    def update_scale_scalar_dof(self, shift):
        """No-op in this class; ``shift`` is ignored."""
        # if there is a shift we will want to avoid pushing the scalar
        return None

    def _import_quant_internal(self, params: dict):
        """
        Import the quantization params into every atomic op and into the layer itself.

        On success the out/in scale ratio is (re)created and the layer is marked as having hw params.

        Args:
            params: flat dict of exported quantization params.

        Raises:
            AccelerasImportError: when a required key is missing from ``params``; the original
                ``KeyError`` is attached as the cause.
        """
        try:
            for op in self.atomic_ops:
                op.import_quant(params)
            self.import_layer_params(params)
        except KeyError as e:
            # the trailing space keeps the two sentences apart once the literals are concatenated
            raise AccelerasImportError(
                "You are trying to use quantized HAR that was generated with an older DFC. "
                f"Please run optimize again with the current version. Missing param: {e.args} in {self.full_name}.",
            ) from e
        self._create_out_in_scale_ratio()
        self._has_hw_params = True

    # def import_qnpz(self, qnpz: dict, skipped_keys: Set[str] = None):
    #     quant_params = dict()
    #     skipped_keys = set() if skipped_keys is None else skipped_keys
    #     legacy_keys = {"qp_in", "qp_out", "limvals_in", "limvals_out"}
    #     skipped_keys = legacy_keys | skipped_keys
    #     for qnpz_key in qnpz:
    #         if qnpz_key.startswith("stats/") or qnpz_key.startswith("layer_params/"):
    #             quant_params[qnpz_key] = qnpz[qnpz_key]
    #         elif qnpz_key in skipped_keys:
    #             pass
    #         else:
    #             self._match_op_qnpz_to_quant_relations(qnpz[qnpz_key], qnpz_key, quant_params)
    #     self.import_stats(quant_params)
    #     self._import_quant_internal(quant_params)
    #     self._validate_io_encoding_params(qnpz)

    def export_stats(self):
        """
        Export the input, preact and output stats of the layer, plus the original output stats
        when they were kept.
        """
        collected = {
            **self._export_input_stats(),
            **self._export_preact_stats(),
            **self._export_output_stats(),
        }
        if self._original_output_stats is not None:
            collected.update(self._export_original_output_stats())
        return collected

    def _export_input_stats(self):
        """
        Export the stats of every layer input under ``stats/input_<index>/...``.

        Inputs whose op has not completed stats collection are skipped with a warning.
        """
        exported = {}
        for layer_inp_index, (op, op_inp_index) in enumerate(self._input_stats_ops()):
            if op.stats_collection_state == StatsState.COMPLETE:
                inp_stats = op.get_input_stats(op_inp_index)
                for cfg_key in op.stats_cfg:
                    exported[f"stats/input_{layer_inp_index}/{cfg_key.value}"] = getattr(inp_stats, cfg_key.value)
                exported[f"stats/input_{layer_inp_index}/stats_limvals"] = op.get_input_limvals(op_inp_index)
            else:
                self._logger.warning(f"Failed exporting input stats to layer {self.full_name}")
        return exported

    def _export_preact_stats(self):
        """
        Export the pre-activation stats (input 0 of the activation op) under ``stats/preact/...``.

        Returns an empty dict when the layer has no activation op, or (with a warning) when the
        activation op has not completed stats collection.
        """
        exported = {}
        # assumes single activation, and single input
        op = self.activation_atomic_op
        if op is None:
            return exported
        if op.stats_collection_state != StatsState.COMPLETE:
            self._logger.warning(f"Failed exporting preact stats to layer {self.full_name}")
            return exported
        preact_stats = op.get_input_stats(0)
        for cfg_key in op.stats_cfg:
            exported[f"stats/preact/{cfg_key.value}"] = getattr(preact_stats, cfg_key.value)
        exported["stats/preact/stats_limvals"] = op.get_input_limvals(0)
        return exported

    def _export_output_stats(self):
        """
        Export the stats of every layer output under ``stats/output_<index>/...``.

        Outputs whose op has not completed stats collection are skipped with a warning.
        """
        exported = {}
        for layer_out_index, (op, op_out_index) in enumerate(self._output_stats_ops()):
            if op.stats_collection_state == StatsState.COMPLETE:
                # op_out_index is currently ignored, assumes 1 output
                out_stats = op.get_output_stats(op_out_index)
                for cfg_key in op.stats_cfg:
                    exported[f"stats/output_{layer_out_index}/{cfg_key.value}"] = getattr(out_stats, cfg_key.value)
                exported[f"stats/output_{layer_out_index}/stats_limvals"] = op.get_output_limvals(op_out_index)
            else:
                self._logger.warning(f"Failed exporting output stats to layer {self.full_name}")
        return exported

    def _export_original_output_stats(self):
        """
        Export the kept original output stats under ``stats/original_output_<index>/...``.

        Fields whose value is None are omitted; the (min, max) limvals are always exported.
        """
        exported = {}
        for layer_out_index, out_stats in enumerate(self._original_output_stats):
            for field_name in out_stats._fields:
                field_value = getattr(out_stats, field_name)
                if field_value is None:
                    continue
                exported[f"stats/original_output_{layer_out_index}/{field_name}"] = field_value
            limvals = (np.min(out_stats.min), np.max(out_stats.max))
            exported[f"stats/original_output_{layer_out_index}/stats_limvals"] = limvals
        return exported

    def import_stats(self, params):
        """
        Import the layer stats from ``params``.

        Only keys starting with ``stats/`` are used; the prefix is stripped before the entries are
        handed to the input / preact / output / original-output importers.
        """
        prefix = "stats/"
        prefix_len = len(prefix)
        stats = {}
        for key, value in params.items():
            if key.startswith(prefix):
                stats[key[prefix_len:]] = value
        self._import_input_stats(stats)
        self._import_preact_stats(stats)
        self._import_output_stats(stats)
        self._import_original_output_stats(stats)

    def _import_input_stats(self, stats):
        """
        Hand each input-stats op the entries stored under ``input_<index>/`` (prefix stripped).

        ``stats_limvals`` entries are not forwarded.
        """
        for layer_inp_index, (op, op_inp_index) in enumerate(self._input_stats_ops()):
            prefix = f"input_{layer_inp_index}/"
            selected = {}
            for key, value in stats.items():
                if not key.startswith(prefix) or key.endswith("/stats_limvals"):
                    continue
                selected[key[len(prefix) :]] = value
            op.import_input_stats(selected, op_inp_index)

    def _import_preact_stats(self, stats):
        """
        Hand the activation op (if any) the entries stored under ``preact/`` as the stats of its input 0.

        ``stats_limvals`` entries are not forwarded.
        """
        op = self.activation_atomic_op
        if op is None:
            return
        prefix = "preact/"
        selected = {}
        for key, value in stats.items():
            if not key.startswith(prefix) or key.endswith("/stats_limvals"):
                continue
            selected[key[len(prefix) :]] = value
        op.import_input_stats(selected, 0)

    def _import_output_stats(self, stats):
        """
        Hand each output-stats op the entries stored under ``output_<index>/`` (prefix stripped).

        ``stats_limvals`` entries are not forwarded.
        """
        for layer_out_index, (op, op_out_index) in enumerate(self._output_stats_ops()):
            prefix = f"output_{layer_out_index}/"
            selected = {}
            for key, value in stats.items():
                if not key.startswith(prefix) or key.endswith("/stats_limvals"):
                    continue
                selected[key[len(prefix) :]] = value
            op.import_output_stats(selected, op_out_index)

    def _import_original_output_stats(self, stats):
        """
        Restore the kept original output stats from the ``original_output_<index>/`` entries.

        When no such entry exists, ``_original_output_stats`` is reset to None. Otherwise one
        ``ImportedStats`` result is stored per layer output; ``stats_limvals`` entries are ignored.
        """
        has_original = any(k.startswith("original_output_") for k in stats.keys())
        if not has_original:
            self._original_output_stats = None
            return
        self._original_output_stats = []
        for layer_out_index in range(self.num_outputs):
            prefix = f"original_output_{layer_out_index}/"
            selected = {}
            for key, value in stats.items():
                if not key.startswith(prefix) or key.endswith("/stats_limvals"):
                    continue
                selected[key[len(prefix) :]] = value
            self._original_output_stats.append(ImportedStats(selected).get())

    def _change_list_to_np_array(self, list_vectors):
        """
        Stack a list of scale / zero-point vectors of possibly different lengths into one float32 matrix.

        Every row is padded with NaN up to the length of the longest vector. A single-element
        list is converted directly, without padding.
        """
        if len(list_vectors) == 1:
            return np.array(list_vectors, np.float32)

        # Convert keras/TF variables to NumPy arrays eagerly
        as_numpy = [
            vec.numpy().astype(np.float32) if hasattr(vec, "numpy") else np.array(vec, dtype=np.float32)
            for vec in list_vectors
        ]

        # zip_longest pads the short vectors (column-wise); the second zip transposes back to rows
        padded_columns = itertools.zip_longest(*as_numpy, fillvalue=np.nan)
        padded_rows = list(zip(*padded_columns))
        return np.array(padded_rows, np.float32)

    def _update_scalar_to_scale(self, zero_points, scales):
        """
        Expand scalar (or single-element) zero points to the length of their matching scale vector.

        Zero points that are already vectors with more than one element are passed through unchanged.

        Returns:
            list with one (possibly expanded) zero point per entry of ``zero_points``.
        """
        expanded = []
        for idx, zero_point in enumerate(zero_points):
            scale = scales[idx]
            zp_shape = np.array(zero_point).shape
            if zp_shape != () and zp_shape[0] != 1:
                # remove exception for zp scalar
                expanded.append(zero_point)
                continue
            scale_shape = np.array(scale).shape
            repeats = 1 if scale_shape == () else scale_shape[0]
            expanded.append(np.repeat(zero_point, repeats))
        return expanded

    def import_layer_params(self, params):
        """
        Import the layer-level params.

        ``layer_params/negative_slopes_correction_factor`` is required (a missing key raises
        KeyError); ``layer_params/eq_vec_out`` is optional and its absence only logs a warning.
        """
        self._negative_slope_exponent_fix_shift = params["layer_params/negative_slopes_correction_factor"]
        if "layer_params/eq_vec_out" in params:
            self.eq_vec_out = params["layer_params/eq_vec_out"]
        else:
            self._logger.warning("eq_vec_out is not in the params")

    def _create_out_in_scale_ratio(self):
        """No-op in this class; called at the end of ``_import_quant_internal``."""
        return None

    def create_hw_params(
        self, weights_clipping: LayerWeightsClippingConfig, optimization_target: OptimizationTarget, hw_shifts=None
    ):
        """
        Override whenever numerization of atomic ops depend on each-other's numerization
        (almost any multi-op layer..)

        Raises:
            AccelerasImplementationError: always, in this (non-overridden) implementation.
        """
        message = f"create_hw_params is not implemented for {self.full_name}"
        raise AccelerasImplementationError(message)

    def iterate_input_ops(self) -> Iterator[Tuple[BaseAtomicOp, int]]:
        """Yield ``(op, input_index)`` for every op fed directly by one of the layer-flow input nodes."""
        for in_node in self._layer_flow.get_inputs():
            for consumer_name in self._layer_flow.successors_sorted(in_node):
                input_index = self._layer_flow.get_edge_data(in_node, consumer_name)["input_index"]
                yield self._layer_flow.get_op(consumer_name), input_index

    def iterate_output_ops(self) -> Iterator[Tuple[BaseAtomicOp, int]]:
        """Yield ``(op, output_index)`` for every op feeding directly into one of the layer-flow output nodes."""
        for out_node in self._layer_flow.get_outputs():
            for producer_name in self._layer_flow.predecessors_sorted(out_node):
                output_index = self._layer_flow.get_edge_data(producer_name, out_node)["output_index"]
                yield self._layer_flow.get_op(producer_name), output_index

    def _iterate_act_ops(self):
        """Yield the atomic ops of the layer that are ``ActivationOp`` instances, in ``atomic_ops`` order."""
        yield from (op for op in self.atomic_ops if isinstance(op, ActivationOp))

    def get_input_limvals(self) -> List[tuple]:
        """
        Return the input limvals of the layer, one entry per ``(op, input_index)`` pair
        reported by ``_input_stats_ops`` (there may be more than one).
        """
        return [op.get_input_limvals(input_index) for op, input_index in self._input_stats_ops()]

    def get_output_limvals(self) -> List[tuple]:
        """
        Return the output limvals of the layer, one entry per ``(op, output_index)`` pair
        reported by ``_output_stats_ops``.
        """
        return [op.get_output_limvals(output_index) for op, output_index in self._output_stats_ops()]

    def get_group_input_limvals(self, groups: int) -> List[tuple]:
        """
        Return the per-group input limvals, one entry per ``(op, input_index)`` pair
        reported by ``_input_stats_ops``; ``groups`` is forwarded to each op.
        """
        return [op.get_group_input_limvals(input_index, groups) for op, input_index in self._input_stats_ops()]

    def get_group_output_limvals(self, groups: int) -> List[tuple]:
        """
        Return the per-group output limvals, one entry per ``(op, output_index)`` pair
        reported by ``_output_stats_ops``; ``groups`` is forwarded to each op.
        """
        return [op.get_group_output_limvals(output_index, groups) for op, output_index in self._output_stats_ops()]

    def get_original_output_limvals(self) -> List[tuple]:
        """
        Return the (min, max) of the kept original output stats, one tuple per output.

        When no original output stats were kept, the current output limvals are returned instead.
        """
        original_stats = self._original_output_stats
        if original_stats is None:
            return self.get_output_limvals()
        return [(np.min(out_stats.min), np.max(out_stats.max)) for out_stats in original_stats]

    def get_input_lossy_elements(self) -> List[QuantElement]:
        """Return the input lossy element of every ``(op, input_index)`` pair yielded by ``iterate_input_ops``."""
        return [op.input_lossy_elements[input_index] for op, input_index in self.iterate_input_ops()]

    def get_weight_lossy_elements(self) -> List[QuantElement]:
        """Return the ``weight_lossy_elements`` attribute of every atomic op, in ``atomic_ops`` order."""
        return [op.weight_lossy_elements for op in self.atomic_ops]

    def get_output_lossy_elements(self) -> List[QuantElement]:
        """Return the output lossy element of every ``(op, output_index)`` pair yielded by ``iterate_output_ops``."""
        return [op.output_lossy_elements[output_index] for op, output_index in self.iterate_output_ops()]

    def get_input_stats(self) -> List[Statistics]:
        """
        Return the input statistics of the layer, one entry per ``(op, input_index)`` pair
        reported by ``_input_stats_ops`` (there may be more than one).
        """
        return [op.get_input_stats(input_index) for op, input_index in self._input_stats_ops()]

    def get_output_stats(self) -> List[Statistics]:
        """
        Return the output statistics of the layer, one entry per ``(op, output_index)`` pair
        reported by ``_output_stats_ops``.
        """
        return [op.get_output_stats(output_index) for op, output_index in self._output_stats_ops()]

    def keep_original_output_stats(self):
        """Snapshot (deep copy) the current output stats, unless a snapshot was already kept."""
        if self._original_output_stats is not None:
            return
        self._original_output_stats = deepcopy(self.get_output_stats())

    # TODO consider removing these - not for every composite op they make sense
    def get_preact_limvals(self) -> List[tuple]:
        """
        Return the limvals of input 0 of every activation op in the layer.

        When the layer has no activation op, the output limvals are returned instead.
        """
        act_limvals = [act_op.get_input_limvals(0) for act_op in self._iterate_act_ops()]
        if act_limvals:
            return act_limvals
        return self.get_output_limvals()

    def get_preact_stats(self) -> List[Statistics]:
        """
        Return the stats of input 0 of every activation op in the layer.

        When the layer has no activation op, the output stats are returned instead.
        """
        act_stats = [act_op.get_input_stats(0) for act_op in self._iterate_act_ops()]
        if act_stats:
            return act_stats
        return self.get_output_stats()

    def get_weights_clipping(self) -> Optional[tuple]:
        """
        Return the clipping values of the layer's kernel.

        This implementation has no kernel clipping to report and returns None.
        """
        return None

    @property
    def activation_atomic_op(self) -> Optional[ActivationOp]:
        """
        The single activation op of the layer, or None when the layer has no activation.

        Raises:
            RuntimeError: when the layer has more than one activation op.
        """
        act_ops = list(self._iterate_act_ops())
        if not act_ops:
            return None
        if len(act_ops) > 1:
            raise RuntimeError(
                f"Layer {self.full_name} has {len(act_ops)} activations which is not supported (expected up to 1)",
            )
        return act_ops[0]

    @property
    def is_changing_bias_supported(self):
        """Whether the native bias can be imported / exported; False in this implementation."""
        return False

    def import_native_bias(self, native_bias):
        """
        Not supported in this implementation.

        Raises:
            AccelerasImplementationError: always.
        """
        message = f"import_native_bias is not supported for {self.full_name}"
        raise AccelerasImplementationError(message)

    def export_native_bias(self):
        """
        Not supported in this implementation.

        Raises:
            AccelerasImplementationError: always.
        """
        message = f"export_native_bias is not supported for {self.full_name}"
        raise AccelerasImplementationError(message)

    @abstractmethod
    def import_weights(self, layer_params: LayerParams, **kwargs):
        """Import the layer's weights from ``layer_params``; abstract, has no implementation here."""

    def import_translation_config(
        self,
        translation_config: LayerTranslationConfig,
    ):
        """
        Apply the layer's translation config to its atomic ops.

        Bias-add and elementwise-add ops get their max feed repeat, activation ops get their
        fit policy, and every op gets the ``ignore_hw_limitation_assertion`` flag.

        Args:
            translation_config - LayerTranslationConfig, describes various quantization config
        """
        # first matching type wins, same precedence as an if/elif chain
        per_type_setup = (
            (AddBiasOp, lambda op: op.set_max_feed_repeat(translation_config.max_bias_feed_repeat)),
            (ActivationOp, lambda op: op.set_fit_policy(translation_config.activation_fit)),
            (ElementwiseAddOp, lambda op: op.set_max_feed_repeat(translation_config.max_elementwise_feed_repeat)),
        )
        for op in self.atomic_ops:
            for op_type, configure in per_type_setup:
                if isinstance(op, op_type):
                    configure(op)
                    break
            op.set_ignore_hw_limitation_assertion(translation_config.ignore_hw_limitation_assertion)

    def import_precision_config(self, precision_config: LayerPrecisionConfig, optimization_target: OptimizationTarget):
        """
        Create the quant elements of the layer according to the precision mode.

        Accumulator and layer-input data paths are always handled. The layer-output data paths and
        the layer's custom quant-element behavior are only handled when the precision mode defines
        output bits.
        """
        precision_mode = precision_config.precision_mode

        accumulator_bits = get_accumulator_bits_by_precision_mode(precision_mode)
        self.create_quant_element_by_data_path(DataPath.ACCUMULATOR, accumulator_bits)

        input_bits = get_input_bits_by_precision_mode(precision_mode)
        for in_path in (DataPath.LAYER_IN, DataPath.LAYER_IN_WEIGHTS):
            self.create_quant_element_by_data_path(in_path, input_bits)

        output_bits = get_output_bits_by_precision_mode(precision_mode)
        if output_bits is None:
            return
        for out_path in (DataPath.LAYER_OUT, DataPath.LAYER_OUT_WEIGHTS):
            self.create_quant_element_by_data_path(out_path, output_bits)
        self.create_quant_element_custom_behavior(precision_config, optimization_target)

    def verify_config(
        self,
        precision_config: LayerPrecisionConfig,
    ):
        """
        Verify ``precision_config`` for this layer: first the acceleras-support check,
        then the layer's own precision-config check.
        """
        checks = (self._verify_supported_precision_cfg_in_acceleras, self._verify_precision_config)
        for check in checks:
            check(precision_config)

    def create_quant_element_custom_behavior(
        self, precision_config: LayerPrecisionConfig, optimization_target: OptimizationTarget, **kwargs
    ):
        """
        In this function the layer should initialize the quant element for the weights data_path or for unique
        data_paths.

        For example - conv and bias weights, activation data (offset + mant), or ew_mult_op's data_path in the apu.
        By default - the function makes sure there's no additional data_path in the flow and that there are no
        uninitialized weights lossy elements

        Raises:
            AccelerasImplementationError: when an edge of the flow has a data path outside the default set,
                or when an op's ``weight_lossy_elements`` is not a ``BaseWeightLossyElements``.
        """
        # TODO: consider adding verification that the layer has no weights / no addtional data_paths
        handled_data_paths = {
            DataPath.LAYER_IN,
            DataPath.LAYER_OUT,
            DataPath.ACCUMULATOR,
            DataPath.LAYER_IN_WEIGHTS,
            DataPath.LAYER_OUT_WEIGHTS,
        }
        for u, v in self._layer_flow.edges:
            data_path = self._layer_flow.get_edge_data_path(u, v)
            if data_path not in handled_data_paths:
                raise AccelerasImplementationError(f"Layer {self.full_name} has unhandled data_path {data_path.value}")
        for op in self.atomic_ops:
            if not isinstance(op.weight_lossy_elements, BaseWeightLossyElements):
                # report the offending op; the edge loop's data_path is unrelated here
                # (and is not even defined when the flow has no edges)
                raise AccelerasImplementationError(
                    f"Layer {self.full_name} has unhandled lossy_elements in op {op.full_name}"
                )

    def create_quant_element_by_data_path(self, data_path, bits):
        """
        Create a quant element for every flow edge whose data path equals ``data_path`` and attach it
        to both ends of the edge.

        The element is set as the output lossy element of the edge source (unless the source is a
        flow input node) and as the input lossy element of the edge destination (unless it is a
        flow output node).
        """
        flow = self._layer_flow
        flow_inputs = flow.get_inputs()
        flow_outputs = flow.get_outputs()
        for src, dst in flow.edges:
            if flow.get_edge_data_path(src, dst) != data_path:
                continue
            edge_label = f"qe:{os.path.basename(src)}-{os.path.basename(dst)}"
            quant_element = self._create_quant_element(data_path, bits, os.path.join(self.full_name, edge_label))
            if src not in flow_inputs:
                src_op = flow.get_op(src)
                src_op.set_output_lossy_element(quant_element, index=flow.get_edge_output_index(src, dst))
            if dst not in flow_outputs:
                dst_op = flow.get_op(dst)
                dst_op.set_input_lossy_element(quant_element, index=flow.get_edge_input_index(src, dst))

    def _create_quant_element(self, data_path: DataPath, bits: int, name: str):
        """Instantiate the quant element class selected by ``get_quant_element_by_data_path`` for this data path and bit width."""
        return get_quant_element_by_data_path(data_path, bits)(bits, name=name)

    def get_inputs_data_path(self):
        """
        Return the data path of each flow input node, ordered like ``_layer_flow.get_inputs()``.

        An input with several outgoing edges reports the data path of the last such edge visited;
        an input without outgoing edges reports None.
        """
        input_nodes = self._layer_flow.get_inputs()
        data_paths = [None] * len(input_nodes)
        for src, dst in self._layer_flow.edges:
            if src in input_nodes:
                edge_data_path = self._layer_flow.get_edge_data_path(src, dst)
                data_paths[input_nodes.index(src)] = edge_data_path
        return data_paths

    def set_input_data_path(self, data_path, input_index=0):
        """
        Re-label the layer-input edges reachable from the given flow input node with ``data_path``.

        Starting at the input node, every outgoing edge whose data path is LAYER_IN or
        LAYER_IN_WEIGHTS is re-labeled and the walk continues from its target node.
        """
        relabeled_paths = [DataPath.LAYER_IN, DataPath.LAYER_IN_WEIGHTS]
        pending = [self._layer_flow.get_inputs()[input_index]]
        while pending:
            node = pending.pop()
            for succ in self._layer_flow.successors_sorted(node):
                if self._layer_flow.get_edge_data_path(node, succ) not in relabeled_paths:
                    continue
                self._layer_flow.edges[node, succ]["data_path"] = data_path
                pending.append(succ)

    def set_output_data_path(self, data_path):
        """Re-label every flow edge whose data path is LAYER_OUT or LAYER_OUT_WEIGHTS with ``data_path``."""
        relabeled_paths = [DataPath.LAYER_OUT, DataPath.LAYER_OUT_WEIGHTS]
        for src, dst in self._layer_flow.edges:
            if self._layer_flow.get_edge_data_path(src, dst) not in relabeled_paths:
                continue
            self._layer_flow.edges[src, dst]["data_path"] = data_path

    def _verify_precision_config(
        self,
        precision_config: LayerPrecisionConfig,
    ):
        """
        verify the precision config of the layer is ok - we have two kinds of errors:
        1. if we didn't yet support this configuration in acceleras- then we have AccelerasImplementationError
        2. if we don't support the configuration at all - (hw constraints) then we have AccelerasHWUnsupportedError
        this function relates to the latter

        The precision mode and bias mode are checked against the layer's supported values, the config is
        then stored on the layer, and finally the quantization groups are checked.

        Args:
            precision_config:class:
            hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer.LayerPrecisionConfig

        """
        supported_by_field = (
            ("precision_mode", self.SUPPORTED_PRECISION_MODE),
            ("bias_mode", self.SUPPORTED_BIAS_MODE),
        )
        for field, supported_values in supported_by_field:
            self._verify_single_precision_field(precision_config, field, supported_values)
        self._layer_precision_config = precision_config
        self._verify_quantization_groups(precision_config)

    def _verify_single_precision_field(self, precision_config, field, supported_values):
        given_value = getattr(precision_config, field)
        if given_value not in supported_values:
            raise AccelerasImplementationError(
                f"{self.full_name} does not support {given_value} {field}. Supported values are: {supported_values}",
            )

    def _verify_quantization_groups(self, precision_config: LayerPrecisionConfig):
        """
        Raise if the config asks for quantization groups on a layer that does not support them.

        A value of 1 is always accepted; any other value requires SUPPORTED_QUANTIZATION_GROUPS
        to be truthy.

        Raises:
            AccelerasImplementationError: when groups != 1 and the layer does not support groups

        """
        groups = precision_config.quantization_groups
        if groups == 1 or self.SUPPORTED_QUANTIZATION_GROUPS:
            return
        raise AccelerasImplementationError(
            f"{self.full_name} does not support quantization_groups. Expected 1, but got {groups}.",
        )

    @property
    def is_numerized(self):
        """Expose the value of the ``_has_hw_params`` attribute."""
        has_hw_params = self._has_hw_params
        return has_hw_params

    @property
    def input_shapes(self):
        """
        List of the layer's input shapes, one per (op, index) pair from ``iterate_input_ops``.

        When the op reports valid shapes its own input shape is used; otherwise the shape is
        taken from the HN element with the first dimension replaced by None. Every shape is
        passed through ``_convert_to_rank4``.
        """
        resolved = []
        for op, in_index in self.iterate_input_ops():
            if not op.output_shapes_is_valid:
                hn_shape = self._hn_element["input_shapes"][in_index]
                shape = [None, *hn_shape[1:]]
            else:
                shape = op.input_shapes[in_index]
            resolved.append(self._convert_to_rank4(shape))
        return resolved

    @property
    def input_shape(self):
        """
        The single input shape of the layer.

        Raises:
            AttributeError: when ``input_shapes`` does not hold exactly one entry

        """
        shapes = self.input_shapes
        if len(shapes) == 1:
            return shapes[0]
        raise AttributeError(
            'The layer has multiple inputs. Hence the notion of "input_shape" is '
            f"ill-defined for the layer - {self.full_name}",
        )

    @property
    def output_shapes(self):
        """
        List of the layer's output shapes, one per (op, index) pair from ``iterate_output_ops``.

        When the op reports valid shapes its own output shape is used; otherwise the shape is
        taken from the HN element with the first dimension replaced by None. Every shape is
        passed through ``_convert_to_rank4``.
        """
        resolved = []
        for op, out_index in self.iterate_output_ops():
            if not op.output_shapes_is_valid:
                hn_shape = self._hn_element["output_shapes"][out_index]
                shape = [None, *hn_shape[1:]]
            else:
                shape = op.output_shapes[out_index]
            resolved.append(self._convert_to_rank4(shape))
        return resolved

    @property
    def output_shape(self):
        """
        The single output shape of the layer.

        Raises:
            AttributeError: when ``output_shapes`` does not hold exactly one entry

        """
        output_shapes = self.output_shapes
        if len(output_shapes) != 1:
            # the message names "output_shape" (it previously named "input_shape" by mistake)
            raise AttributeError(
                f"The layer has multiple outputs. "
                f'Hence the notion of "output_shape" is '
                f"ill-defined for the layer - {self.full_name}",
            )
        return output_shapes[0]

    @staticmethod
    def _get_dim(input_shape):
        dim = 0
        val = input_shape
        for i in range(100):
            if isinstance(val, int) or val is None:
                return dim
            else:
                dim += 1
                val = val[0]
        raise AccelerasValueError("input shape values should have ints")

    def build(self, input_shape):
        """
        Build the layer and its atomic ops for the given input shape(s).

        Args:
            input_shape: a single shape, a nested sequence of shapes, or a dict whose values
                are the shapes

        """
        if isinstance(input_shape, dict):
            # only the dict values are used, in the dict's iteration order
            input_shape = [v for v in input_shape.values()]
        # nesting depth of the argument: 1 for a single flat shape, more for a list of shapes
        dims = self._get_dim(input_shape)
        if dims > 1:
            self.verify_layer_inputs_shape(input_shape)
        else:
            # the verification hook always receives a list of shapes
            self.verify_layer_inputs_shape([input_shape])
        if dims > 1 and self.num_inputs == 1:
            # single-input layer that was handed a one-element list: unwrap it
            input_shape = input_shape[0]
        # the return value is discarded here; the call is made for its effect on the ops
        # (presumably each op records its shapes while computing them - TODO confirm)
        self.compute_output_shape(input_shape)
        if not self.ignore_io_shapes_verification:
            self._verify_and_set_hn_io_shapes()
        for op in self.atomic_ops:
            if not op.built:
                op.build(op.input_shapes)
        # layer-specific build hook
        self._build(input_shape)
        if self.infer_in_build:
            self.enforce_internal_encoding()
        self._build_input_shape = input_shape
        self.built = True

    def _build(self, input_shapes):
        pass

    def compute_output_shape(self, input_shapes):
        """
        Propagate shapes through the layer flow and return the layer's output shape(s).

        Args:
            input_shapes: the shape for a single-input layer, or one shape per input node

        Returns:
            the shape feeding the single output node, or a list with one shape per output node

        Raises:
            ValueError: when a multi-input layer receives a different number of shapes
            RuntimeError: when an output node does not have exactly one predecessor

        """
        flow = self._layer_flow

        # seed the table with the shapes entering the layer
        entry_nodes = flow.get_inputs()
        if len(entry_nodes) == 1:
            shape_by_node = {entry_nodes[0]: input_shapes}
        elif len(input_shapes) == len(entry_nodes):
            shape_by_node = dict(zip(entry_nodes, input_shapes))
        else:
            raise ValueError(
                f"Inputs and input nodes not the same length in layer {self.full_name} - "
                f"inputs: {len(input_shapes)}, nodes: {len(entry_nodes)}",
            )

        # walk the ops in topological order so every predecessor shape is already known
        for op_name in flow.toposort_ops():
            pred_names = flow.predecessors_sorted(op_name)
            op = flow.get_op(op_name)
            incoming = [shape_by_node[pred_name] for pred_name in pred_names]
            # an op with one predecessor receives the bare shape rather than a one-element list
            op_input = incoming[0] if len(pred_names) == 1 else incoming
            shape_by_node[op_name] = op.compute_output_shape(op_input)

        # collect the shape produced by the op that feeds each output node
        layer_outputs = []
        for sink in flow.get_outputs():
            producers = list(flow.predecessors(sink))
            if len(producers) != 1:
                raise RuntimeError(
                    f"Unexpected output predecessors count in layer {self.full_name}, predecessors: {producers}"
                )
            layer_outputs.append(shape_by_node[producers[0]])

        return layer_outputs[0] if len(layer_outputs) == 1 else layer_outputs

    def get_layer_precision_config(self):
        """
        Collect the layer's current precision settings as a raw dict.

        Returns:
            the ``raw_dict()`` of a LayerPrecisionConfig built from the layer's getters, or None
            when any step raises AttributeError

        """
        try:
            fields = {
                "bias_mode": self.get_bias_mode(),
                "precision_mode": self.get_precision_mode(),
                "quantization_groups": self.get_quantization_groups(),
                "signed_output": self.get_signed_output(),
                "quantization_weight_groups": self.get_quantization_weight_groups(),
            }
            return LayerPrecisionConfig(**fields).raw_dict()
        except AttributeError:
            # TODO: find a proper way to indicate if config was loaded
            return None

    def to_hn(self, out_degree: Optional[int] = None) -> dict:
        """
        Export the layer as an HN dictionary.

        ``self._hn_element`` is updated in place ("type", "input_shapes", "output_shapes" and
        "quantization_params" may be written to it) before being merged into the returned dict.

        Args:
            out_degree: number of consumers of the layer output; defaults to ``self.num_outputs``

        Returns:
            dict: result of ``update_nested`` applied to a fresh dict and the layer's HN element

        """
        if out_degree is None:
            out_degree = self.num_outputs
        out_degree = max(1, out_degree)  # in case the layer has no outputs (output layer)
        hn_dict = self._hn_element
        if "type" not in hn_dict:
            hn_dict["type"] = self._hn_type.value
        if (
            all(op.output_shapes_is_valid for op, _ in self.iterate_input_ops())
            and not self.ignore_io_shapes_verification
        ):
            # We check if the layer has been built because we assume the build called the layer ops.
            # If we ignore shape verification, the model's resolution is reduced, and we don't want to export
            # the reduced shape to hn.
            hn_dict["input_shapes"] = [[int(v) for v in shapes] for shapes in self._get_hn_input_shapes()]
            hn_dict["output_shapes"] = [[int(v) for v in shapes] for shapes in self._get_hn_output_shapes()]
            # TODO: handle by the model
            # the exported output shapes list is repeated ceil(out_degree / num_outputs) times
            # NOTE(review): divides by self.num_outputs - confirm it cannot be 0 on this path
            hn_dict["output_shapes"] = hn_dict["output_shapes"] * int(np.ceil(out_degree / self.num_outputs))
        self._verify_exportable(hn_dict)
        hn = dict()
        # None means the precision settings could not be read, in which case the key is not written
        quantization_params = self.get_layer_precision_config()
        if quantization_params is not None:
            hn_dict["quantization_params"] = quantization_params

        return update_nested(hn, hn_dict)

    @property
    def outputs_dim(self):
        """
        Rank of each output shape after conversion with ``_convert_to_rank4``.

        The shapes come from the HN element's "output_shapes" entry when present, otherwise
        from ``self.output_shapes``.
        """
        hn_element = self._hn_element
        fallback = self.output_shapes
        raw_shapes = hn_element.get("output_shapes", fallback)
        return [len(self._convert_to_rank4(raw_shape)) for raw_shape in raw_shapes]

    @property
    def inputs_dim(self):
        """
        Rank of each input shape after conversion with ``_convert_to_rank4``.

        The shapes come from the HN element's "input_shapes" entry when present, otherwise
        from ``self.input_shapes``.
        """
        hn_element = self._hn_element
        fallback = self.input_shapes
        raw_shapes = hn_element.get("input_shapes", fallback)
        return [len(self._convert_to_rank4(raw_shape)) for raw_shape in raw_shapes]

    def verify_layer_inputs_shape(self, input_shapes):
        """Input-shape verification hook called from ``build``; this implementation accepts anything."""
        return None

    def _convert_to_rank4(self, shape):
        if len(shape) == 2:
            return [shape[0], 1, 1, shape[1]]
        return shape

    def _verify_and_set_hn_io_shapes(self):
        input_shapes = self._hn_element.get("input_shapes")
        if input_shapes:
            input_shapes = [self._convert_to_rank4(shape) for shape in input_shapes]
            if not self._verify_hn_to_keras_input_shapes(self._get_hn_input_shapes(), input_shapes):
                raise AccelerasValueError(
                    f"Inference input shapes {self._get_hn_input_shapes()} for layer {self.full_name} "
                    f"does not match HN shapes {input_shapes}",
                )

        output_shapes = self._hn_element.get("output_shapes")
        if output_shapes:
            output_shapes = [self._convert_to_rank4(shape) for shape in output_shapes]
            if (
                len(output_shapes) != len(self._get_hn_output_shapes())
                and np.array([output_shapes[0] == shape for shape in output_shapes]).all()
            ):
                # TODO: support multi output and split https://hailotech.atlassian.net/browse/SDK-38641
                self._hn_output_split = len(output_shapes)
                output_shapes = [output_shapes[0]]
            if not self._verify_hn_to_keras_output_shapes(self._get_hn_output_shapes(), output_shapes):
                raise AccelerasValueError(
                    f"Inference output shapes {self._get_hn_output_shapes()}"
                    f" does not match hn shapes {output_shapes} in {self.full_name}",
                )

    @staticmethod
    def _verify_hn_to_keras_input_shapes(keras_shapes, hn_shapes):
        if len(keras_shapes) != len(hn_shapes):
            return False
        for keras_shape, hn_shape in zip(keras_shapes, hn_shapes):
            if not (np.array(keras_shape) == np.array(hn_shape)).all():
                return False
        return True

    @staticmethod
    def _verify_hn_to_keras_output_shapes(keras_shapes, hn_shapes):
        return BaseHailoLayer._verify_hn_to_keras_input_shapes(keras_shapes, hn_shapes)

    def _get_hn_input_shapes(self):
        res = []
        for index, shape in enumerate(self.input_shapes):
            if self.inputs_dim[index] == 2:
                res.append([-1, 1, 1, shape[-1]])
            elif self.inputs_dim[index] == 4:
                res.append([-1, *list(np.array(shape[1:]))])
            else:
                raise AccelerasValueError(f"inputs_dim of {self.full_name} must be 2 or 4 but got {self.inputs_dim}")
        return res

    def _get_hn_output_shapes(self):
        res = []
        for index, shape in enumerate(self.output_shapes):
            if self.outputs_dim[index] == 2:
                res.append([-1, 1, 1, shape[-1]])
            elif self.outputs_dim[index] == 4:
                res.append([-1, *list(np.array(shape[1:]))])
            else:
                raise AccelerasValueError(f"outputs_dim of {self.full_name} must be 2 or 4 but got {self.outputs_dim}")
        return res

    def _init_flow(self) -> LayerFlow:
        """
        Create a LayerFlow holding one node per atomic op stored as an instance attribute.

        Only values found directly in ``self.__dict__`` that are BaseAtomicOp instances are
        added; no edges are created here.
        """
        flow = LayerFlow()
        atomic_ops = list(filter(lambda attr: isinstance(attr, BaseAtomicOp), self.__dict__.values()))
        for atomic_op in atomic_ops:
            flow.add_node(atomic_op)
        return flow

    def _verify_supported_precision_cfg_in_acceleras(self, precision_config: LayerPrecisionConfig):
        """
        verify the precision_config is supported in acceleras

        The precision mode is checked first, then the bias mode.

        Args:
            precision_config:class:
            hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer.LayerPrecisionConfig

        Raises:
            AccelerasImplementationError: when either mode is not in the matching
                ``supported_*_mode_acceleras`` collection

        """

        def _require_supported(mode, supported_modes, field_name):
            if mode not in supported_modes:
                raise AccelerasImplementationError(
                    f"Acceleras does not support {mode.value} {field_name} (layer {self.full_name})",
                )

        # TODO - we will want to support the rest of the configurations in the near future
        _require_supported(precision_config.precision_mode, self.supported_precision_mode_acceleras, "precision_mode")
        _require_supported(precision_config.bias_mode, self.supported_bias_mode_acceleras, "bias_mode")

    def get_activation_name(self):
        """Return ``act_name`` of the layer's activation op, or None when there is no such op."""
        activation = self.activation_atomic_op
        return None if activation is None else activation.act_name

    def _visualize_flow(self, dest_path):
        self._layer_flow.visualize(dest_path)

    def get_scalar_vector(self, vector, eps=0):
        """
        Reduce a vector whose entries are all equal (up to eps) to a single value.

        Args:
            vector: a float / np.float32, or an object with ``shape`` and indexing
            eps: relative tolerance passed to ``check_vector_scales``

        Returns: ``vector`` itself when it is a float or has shape ``()``, otherwise its first
            element after ``check_vector_scales`` accepted it

        """
        if isinstance(vector, (np.float32, float)):
            return vector
        if vector.shape == ():
            return vector
        self.check_vector_scales(vector, rtol=eps)
        return vector[0]

    def check_vector_scales(self, vector, rtol=1e-5, atol=1e-7):
        """
        Verify that all entries of ``vector`` equal its first entry within the given tolerances.

        Args:
            vector: scalar or array-like to check
            rtol: relative tolerance forwarded to ``np.allclose``
            atol: absolute tolerance forwarded to ``np.allclose``

        Returns:
            True when the check passes (for scalars and for vectors alike)

        Raises:
            AccelerasNumerizationError: when some entry differs from the first one

        """
        # work on an array view so plain lists are supported in the arithmetic below as well
        values = np.asarray(vector)
        if values.shape == ():
            return True
        if not np.allclose(values[0] * np.ones_like(values), values, rtol=rtol, atol=atol):
            raise AccelerasNumerizationError(
                f"In layer {self.full_name} the vector must be a scalar but there is a diff "
                f"{np.max(np.abs(values - values[0]) / values)}",
            )
        # success is reported the same way as in the scalar case above
        return True

    # region properties for scale matching

    @property
    def vector_zp_in(self):
        """
        when there is a kernel that is trainable and (1,1) we can allow a vector of zero_point input

        True when at least one atomic op is a trainable ConvStrippedOp with kernel size (1, 1)
        and a padding constant of 0.
        """

        def _allows_vector_zp(atom):
            return (
                isinstance(atom, ConvStrippedOp)
                and atom.trainable
                and atom.kernel_size == (1, 1)
                and atom.padding_const_value == 0
            )

        return np.any([_allows_vector_zp(atom) for atom in self.atomic_ops])

    @property
    def consumer_input_scale(self):
        """
        when there is a kernel that is trainable we have a vector degree of freedom for the input scale

        True when at least one atomic op is a trainable ConvStrippedOp with a padding constant of 0.
        """

        def _consumes_input_scale(atom):
            return isinstance(atom, ConvStrippedOp) and atom.trainable and atom.padding_const_value == 0

        return np.any([_consumes_input_scale(atom) for atom in self.atomic_ops])

    @property
    def has_activation(self):
        """
        when there is activation we have a potential scalar degree of freedom for the output scale

        Returns: truthy when at least one atomic op is an ActivationOp

        """
        activation_flags = []
        for atom in self.atomic_ops:
            activation_flags.append(isinstance(atom, ActivationOp))
        return np.any(activation_flags)

    @property
    def homogeneous(self):
        """
        when there is an homogenous activation we have a scalar degree of freedom for the output scale

        Layers without any activation op are homogeneous by default; otherwise at least one
        ActivationOp must report ``homogeneous``.
        """
        if self.has_activation:
            return np.any([isinstance(atom, ActivationOp) and atom.homogeneous for atom in self.atomic_ops])
        # layers with no activation at all are homogeneous by default
        return True

    # endregion

    @property
    def shift_delta(self):
        """
        Return the shift delta of the layer's ConvStrippedOp.

        The value is taken from the op's ``shift_delta`` attribute when the layer has exactly
        one ConvStrippedOp; in every other case None is returned.
        """
        conv_ops = list(filter(lambda op: isinstance(op, ConvStrippedOp), self.atomic_ops))
        if len(conv_ops) == 1:
            return conv_ops[0].shift_delta
        return None

    def update_negative_slope_exponent_shift(self, output_shift):
        """
        Apply a power-of-two shift to the output scale and compensate in the activation op.

        The shift is stored on ``_negative_slope_exponent_fix_shift``, output scale 0 is
        multiplied by ``2**output_shift``, the activation op's ``output_factor_by_group`` is
        divided by the same factor, and then the activation decomposition and the internal
        encoding are refreshed.

        Args:
            output_shift: exponent of the power-of-two factor

        """
        scale_factor = 2**output_shift
        self._negative_slope_exponent_fix_shift = output_shift
        rescaled_output = self.output_scale * scale_factor
        self.set_output_scale(rescaled_output, 0)
        activation = self.activation_atomic_op
        activation.output_factor_by_group /= scale_factor
        activation.update_mantissa_exponent_decomposition()
        self.enforce_internal_encoding()

    # region Precision config

    def is_supported_by_hw(self, optimization_target, lcfg: Optional[LayerPrecisionConfig] = None):
        """
        Tell whether a precision configuration is supported for the given optimization target.

        Args:
            optimization_target: the target to check against
            lcfg: configuration to check; when None the layer's current settings are read, and
                falsy fields of a given config fall back to the layer defaults

        Returns:
            False when the current settings cannot be read (AttributeError); otherwise the
            result of the emulation checks or of the hardware checks

        """
        if lcfg is not None:
            precision_mode = lcfg.precision_mode
            bias_mode = lcfg.bias_mode or self.get_default_bias_mode()
            quantization_groups = lcfg.quantization_groups or self.get_default_quantization_groups()
            signed_output = lcfg.signed_output or self.get_default_precision_config().signed_output
        else:
            try:
                precision_mode = self.get_precision_mode()
                bias_mode = self.get_bias_mode()
                quantization_groups = self.get_quantization_groups()
                signed_output = self.get_signed_output()
            except AttributeError:
                return False

        if optimization_target == OptimizationTarget.EMULATION:
            # emulation only looks at the class-level supported sets
            return (
                (precision_mode in self.SUPPORTED_PRECISION_MODE)
                and (bias_mode in self.SUPPORTED_BIAS_MODE)
                and (quantization_groups == 1 or self.SUPPORTED_QUANTIZATION_GROUPS)
            )

        return (
            precision_mode in self._get_precision_mode_supported_in_hw(optimization_target)
            and bias_mode in self._get_bias_mode_supported_in_hw(optimization_target)
            and self._is_precision_config_supported(precision_mode, bias_mode, optimization_target)
            and self.is_quantization_groups_supported_in_hw(quantization_groups, optimization_target)
            and self.is_signed_output_supported_with_precision_mode(
                signed_output, precision_mode, optimization_target
            )
        )

    def _is_precision_config_supported(self, precision_mode, bias_mode, optimization_target):
        """
        Check that the bias mode is compatible with the precision mode.

        The a16_w16 family requires single_scale_decomposition and the a8_w4 family requires
        double_scale_initialization; every other precision mode accepts any bias mode.
        ``optimization_target`` is not used by this implementation.
        """
        needs_single_scale = precision_mode in {
            PrecisionMode.a16_w16_a16,
            PrecisionMode.a16_w16_a8,
            PrecisionMode.a16_w16,
        }
        if needs_single_scale and bias_mode not in {BiasMode.single_scale_decomposition}:
            return False
        needs_double_init = precision_mode in {PrecisionMode.a8_w4_a8, PrecisionMode.a8_w4_a16, PrecisionMode.a8_w4}
        return not (needs_double_init and bias_mode != BiasMode.double_scale_initialization)

    def is_quantization_groups_supported_in_hw(self, quantization_groups, optimization_target):
        """
        Tell whether the requested number of quantization groups is supported.

        A request of -1 is replaced by the activation op's ``num_of_channels`` (and rejected
        when the layer has no activation). One group is always supported; more groups need
        SUPPORTED_QUANTIZATION_GROUPS and, outside emulation, ``_supported_quantization_groups_hw``.
        """
        groups = quantization_groups
        if groups == -1:
            if not self.has_activation:
                return False
            groups = self.activation_atomic_op.num_of_channels
        if groups == 1:
            return True
        layer_supports_groups = self.SUPPORTED_QUANTIZATION_GROUPS
        if optimization_target in {OptimizationTarget.EMULATION}:
            return layer_supports_groups
        return layer_supports_groups and self._supported_quantization_groups_hw(groups, optimization_target)

    def is_signed_output_supported_with_precision_mode(
        self, signed_output: bool, precision_mode: PrecisionMode, optimization_target
    ):
        """
        Tell whether a signed output can be combined with the given precision mode.

        Always True for emulation, for unsigned output, and for precision modes without output
        bits. Otherwise the output must be 8 bits and the weights must not be 16 bits.
        """
        if optimization_target in {OptimizationTarget.EMULATION}:
            return True
        if not signed_output:
            return True
        if not precision_mode.has_output_bits():
            return True
        return precision_mode.output_bits() == 8 and precision_mode.weight_bits() != 16

    def _supported_quantization_groups_hw(self, quantization_groups, optimization_target):
        return False

    def _get_bias_mode_supported_in_hw(self, optimization_target):
        """
        Return the set of bias modes supported for the given optimization target.

        MERCURY, SAGE and PLUTO get SUPPORTED_BIAS_MODE without double_scale_decomposition;
        MARS and EMULATION get SUPPORTED_BIAS_MODE as is; any other target gets an empty set.
        """
        if optimization_target in {OptimizationTarget.MERCURY, OptimizationTarget.SAGE, OptimizationTarget.PLUTO}:
            # by default same as emulator support
            return self.SUPPORTED_BIAS_MODE - {BiasMode.double_scale_decomposition}
        if optimization_target in {OptimizationTarget.MARS, OptimizationTarget.EMULATION}:
            return self.SUPPORTED_BIAS_MODE
        return set()

    def _get_precision_mode_supported_in_hw(self, optimization_target):
        """
        Return the set of precision modes supported for the given optimization target.

        Every known target (MERCURY, SAGE, PLUTO, MARS, EMULATION) gets SUPPORTED_PRECISION_MODE,
        i.e. the same as emulator support; any other target gets an empty set.
        """
        known_targets = {
            OptimizationTarget.MERCURY,
            OptimizationTarget.SAGE,
            OptimizationTarget.PLUTO,
            OptimizationTarget.MARS,
            OptimizationTarget.EMULATION,
        }
        return self.SUPPORTED_PRECISION_MODE if optimization_target in known_targets else set()

    def get_macs(self) -> int:
        """
        This method should return the number of mac operations
        of a layer, default value will be the product of the input shapes (number of elements)

        The first dimension of every input shape is excluded from the product, and the
        per-input products are summed.
        """
        return sum(np.prod(input_shape[1:]) for input_shape in self.input_shapes)

    def get_bops(self) -> int:
        """
        This method returns the number of mac operations
        times the Number of bits factor

        The factor is 0.5 for a8_w4, 1.0 for a8_w8 and for a16_w8_a8, and 4.0 for a16_w16
        (the first three are matched on the reduced precision mode).

        Raises:
            AccelerasImplementationError: when the precision mode matches none of the above

        """
        precision = self.get_precision_mode()
        reduced = precision.reduce()

        if reduced is PrecisionMode.a8_w4:
            factor = 0.5
        elif reduced is PrecisionMode.a16_w16:
            factor = 4.0
        elif reduced is PrecisionMode.a8_w8 or precision is PrecisionMode.a16_w8_a8:
            factor = 1.0
        else:
            raise AccelerasImplementationError(f"Layer {self.full_name} has unsupported precision mode {precision}")
        return int(self.get_macs() * factor)

    def get_precision_mode(self) -> PrecisionMode:
        """
        Derive the layer's precision mode from the bits of its lossy elements.

        All input lossy elements must share one bits value, and so must all output lossy
        elements. A bits value of 15 is reported as 16. The kernel bits come from
        ``_get_kernel_bits`` and default to the input bits when that returns None.

        Raises:
            AccelerasNumerizationError: when the input or output bits are not a single value

        """

        def _widen(bits):
            # 15-bit elements are reported as 16 bits
            return 16 if bits == 15 else bits

        input_values = np.unique([elem.bits for elem in self.get_input_lossy_elements()])
        output_values = np.unique([elem.bits for elem in self.get_output_lossy_elements()])
        for side, values in (("input", input_values), ("output", output_values)):
            if len(values) != 1:
                raise AccelerasNumerizationError(
                    f"Inconsistent {side} bits values in layer {self.full_name}: {values}"
                )

        inp_bits = _widen(input_values[0])
        kernel_bits = self._get_kernel_bits()
        kernel_bits = inp_bits if kernel_bits is None else _widen(kernel_bits)
        out_bits = _widen(output_values[0])
        return PrecisionMode(f"a{inp_bits}_w{kernel_bits}_a{out_bits}")

    def get_bias_mode(self) -> BiasMode:
        """
        Derive the layer's bias mode from the number of bias decompositions of its bias ops.

        With no bias ops the result is single_scale_decomposition. Otherwise all bias ops must
        agree on the count: 0 maps to double_scale_initialization, 1 to
        single_scale_decomposition and 2 to double_scale_decomposition.

        Raises:
            AccelerasNumerizationError: when the bias ops disagree or the count is not 0, 1 or 2

        """
        counts = np.unique(
            [op.weight_lossy_elements.bias_decompose.num_decomposition for op in self.get_bias_ops()]
        )
        if len(counts) > 1:
            raise AccelerasNumerizationError(
                f"Inconsistent num_decompositions values in layer {self.full_name}: {counts}",
            )
        if len(counts) == 0:
            return BiasMode.single_scale_decomposition

        count = counts[0]
        if count == 0:
            return BiasMode.double_scale_initialization
        if count == 1:
            return BiasMode.single_scale_decomposition
        if count == 2:
            return BiasMode.double_scale_decomposition
        raise AccelerasNumerizationError(
            f"Invalid bias decomposition value in layer {self.full_name}: {count}",
        )

    def get_bias_ops(self):
        """
        Return a lazy iterator over the atomic ops that are AddBiasOp instances.

        this is overriden by decomposed layer to return the correct bias ops
        """

        def _is_bias_op(op):
            return isinstance(op, AddBiasOp)

        return filter(_is_bias_op, self.atomic_ops)

    def get_quantization_groups(self) -> int:
        """
        Return the number of quantization groups of the layer.

        Layers whose SUPPORTED_QUANTIZATION_GROUPS is falsy report 1; otherwise the value is
        read from the activation op's ``quantization_groups_num``.
        """
        if self.SUPPORTED_QUANTIZATION_GROUPS:
            return self.activation_atomic_op.quantization_groups_num
        return 1

    def get_quantization_weight_groups(self) -> int:
        """
        Return the number of quantization weight groups of the layer.

        Layers without a truthy SUPPORTED_QUANTIZATION_WEIGHT_GROUPS attribute report 1;
        otherwise ``self.quantization_weight_groups`` is returned.
        """
        supports_weight_groups = getattr(self, "SUPPORTED_QUANTIZATION_WEIGHT_GROUPS", False)
        return self.quantization_weight_groups if supports_weight_groups else 1

    def get_signed_output(self) -> bool:
        """
        Tell whether the layer output is signed.

        True when some edge of the flow is tagged LAYER_OUT_WEIGHTS while the layer's precision
        mode has 8 output bits; False otherwise.
        """
        flow = self._layer_flow
        return any(
            flow.get_edge_data_path(source, target) == DataPath.LAYER_OUT_WEIGHTS
            and self.get_precision_mode().output_bits() == 8
            for source, target in flow.edges
        )

    def _get_kernel_bits(self) -> Optional[int]:
        return None

    # endregion

    @classmethod
    def get_default_bias_mode(cls):
        """
        Return the default bias mode of the layer class.

        single_scale_decomposition is preferred when the class lists it in SUPPORTED_BIAS_MODE;
        otherwise double_scale_initialization is returned.
        """
        preferred = BiasMode.single_scale_decomposition
        if preferred in cls.SUPPORTED_BIAS_MODE:
            return preferred
        return BiasMode.double_scale_initialization

    # region encoding flow

    def get_encoding_flow(self):
        """
        return encoding flow graph with the layer's encodings, and their respected constraints.

        The graph is first filled with the encoding flow of every atomic op, then the layer's
        own encodings and constraints are defined on top of it.
        """
        graph = EncodingFlowGraph()
        for atomic_op in self.atomic_ops:
            graph.update(atomic_op.get_encoding_flow())

        # the constraints helper is created on the graph before the layer encodings are defined
        sub_ops = EncodingSubOp(graph)
        self.define_encodings(graph)
        self.define_constraints(sub_ops)
        return graph

    def define_encodings(self, flow: EncodingFlowGraph):
        """
        Define the encoding nodes of the layer.

        Encoding names should look like '<layer_name>/<key_name>:<index>'. If an equivalent value exists in
        the import_quant/export_quant functions, then key_name should match it.

        This implementation adds no encodings.

        Args:
            flow (EncodingFlowGraph): base encoding flow graph to add the layer's encodings to

        """

    def define_constraints(self, enc: EncodingSubOp):
        """
        Define the constraints between the encoding nodes.

        For every internal edge of the layer flow, the producer's output scale and zero point
        are tied to the consumer's input scale and zero point with identity constraints. Edges
        leaving an input node or entering an output node are skipped, as are edges whose two
        ops both have ``encoding_const`` set.

        Args:
            enc (EncodingSubOp): atomic constraints to define relations between the encoding nodes

        """
        flow = self._layer_flow
        for source, target, edge_data in flow.edges(data=True):
            if flow.nodes[source].get("is_input") or flow.nodes[target].get("is_output"):
                continue
            producer = flow.get_op(source)
            consumer = flow.get_op(target)
            if producer.encoding_const and consumer.encoding_const:
                continue
            output_index = edge_data["output_index"]
            input_index = edge_data["input_index"]
            enc.identity(
                f"{producer.full_name}/output_scale:{output_index}",
                f"{consumer.full_name}/input_scale:{input_index}",
            )
            enc.identity(
                f"{producer.full_name}/output_zero_point:{output_index}",
                f"{consumer.full_name}/input_zero_point:{input_index}",
            )

    def update_encoding(self, encodings):
        """
        Update the layer's encodings.

        This implementation does nothing with the given encodings.

        Args:
            encodings (dict): A dictionary of the form '<layer_name>/<key_name>:<index>': encoding_value.

        """

    def enable_force_pruning(self):
        """Hook for turning forced pruning on; this implementation does nothing."""
        return None

    def disable_force_pruning(self):
        """Hook for turning forced pruning off; this implementation does nothing."""
        return None

    # endregion

    def get_quant_element_train_mode(self):
        """
        Snapshot the train mode of every quant element of the layer.

        Returns:
            dict mapping each BaseQuantElement found among the input, output and weight lossy
            elements of the atomic ops to its current ``train_mode``

        """
        snapshot = {}
        for op in self.atomic_ops:
            lossy_elements = itertools.chain(
                op.input_lossy_elements,
                op.output_lossy_elements,
                op.weight_lossy_elements.__dict__.values(),
            )
            snapshot.update(
                {element: element.train_mode for element in lossy_elements if isinstance(element, BaseQuantElement)}
            )
        return snapshot

    def update_quant_elements_train_mode(self, train_mode):
        """
        Set ``train_mode`` on every quant element of the layer.

        Args:
            train_mode: value assigned to each BaseQuantElement found among the input, output
                and weight lossy elements of the atomic ops

        """
        for op in self.atomic_ops:
            lossy_elements = itertools.chain(
                op.input_lossy_elements,
                op.output_lossy_elements,
                op.weight_lossy_elements.__dict__.values(),
            )
            for element in lossy_elements:
                if not isinstance(element, BaseQuantElement):
                    continue
                element.train_mode = train_mode

    def set_wraparound_loss(self, state):
        """
        Enable or disable the wraparound loss across the layer.

        The switch is applied to every AccumulatorQuantElement among the input and output lossy
        elements of each atomic op, and to the ``bias_decompose`` weight element of each AddBiasOp.

        Args:
            state: truthy to enable the wraparound loss, falsy to disable it

        """
        toggle_name = "enable_wraparound_loss" if state else "disable_wraparound_loss"
        for op in self.atomic_ops:
            for element in itertools.chain(op.input_lossy_elements, op.output_lossy_elements):
                if isinstance(element, AccumulatorQuantElement):
                    getattr(element, toggle_name)()
            if isinstance(op, AddBiasOp):
                getattr(op.weight_lossy_elements.bias_decompose, toggle_name)()

    def update_eq_vec_out(self, padded_factors):
        """
        Accumulate ``padded_factors`` into ``self.eq_vec_out`` by element-wise multiplication.

        When ``eq_vec_out`` is still None it is first initialised to ones shaped like
        ``padded_factors``.
        """
        accumulated = self.eq_vec_out
        if accumulated is None:
            accumulated = np.ones_like(padded_factors)
        accumulated *= padded_factors
        self.eq_vec_out = accumulated
