"""
Some general notes on conventions, terms and design choices:

1. We follow Keras terminology (e.g. the add_weight() API), in which any variable is a "weight",
with "kernel" and "bias" being subtypes thereof (we then add more of our own).

2. We distinguish between:
(A) raw lossy ops (e.g. unscaled clip & round in scalar quantization)
(B) the lossless part of the tensor en(de)coding (e.g. affine transforms)

(!!) To differentiate, we denote them as:
(A) "Bit Reduction" and
(B) "[de]Numerization",
respectively. This facilitates "native" mode by simply skipping the lossy parts,
while also facilitating between-graph comparison by keeping tensors in different modes at the same scales.

3. ?
"""

import copy
import itertools
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_op import BaseOp
from hailo_model_optimization.acceleras.encoding.encoding_data import EncodingType
from hailo_model_optimization.acceleras.encoding.encoding_flow import EncodingFlowGraph
from hailo_model_optimization.acceleras.encoding.encoding_layer import TensorInitializer
from hailo_model_optimization.acceleras.encoding.encoding_sub_ops import EncodingSubOp
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import BaseQuantElement
from hailo_model_optimization.acceleras.statistics.statistics_base import STATS_TYPE_FLOAT, BasicTypeTuple, TypeStats
from hailo_model_optimization.acceleras.statistics.statistics_factory import ImportedStats, Statistics, StatsManager
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EmulationType,
    FlowState,
    LayerFeaturePolicy,
    StatsState,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImportParamConfigMismatch,
    AccelerasInitializationError,
    AccelerasNumerizationError,
    AccelerasPrematureQuantOperation,
    AccelerasValueError,
    InvalidInputShape,
)
from hailo_model_optimization.acceleras.utils.flow_state_utils import AtomicOpState
from hailo_model_optimization.acceleras.utils.hsim_wrapper import HSimWrapper
from hailo_model_optimization.acceleras.utils.opt_utils import update_scale, verify_data_dtype
from hailo_model_optimization.acceleras.utils.to_qnpz_utils import qp_to_limvals


@dataclass
class BaseWeightLossyElements:
    """
    Container for the lossy elements attached to an op's weights.

    The base container declares no fields. Ops are expected to subclass it and add their
    own fields (presumably one per weight lossy element). The values are iterated through
    ``__dict__.values()`` by BaseAtomicOp (e.g. when enabling/disabling lossy elements).
    """


class BaseAtomicOp(BaseOp, ABC):
    """
    AtomicOp provides common functionality to enable numeric emulation of core operation.

    The AtomicOp offers 2 inference modes:
    1. Native mode [WIP]
    2. Emulation mode  ; TODO: rename this mode

    The emulation mode is composed of 5 stages:
    1. Numerize all inputs
    2. Inputs' accuracy reduction / bit reduction
    3. Core operation
    4. Output's accuracy reduction / bit reduction
    5. Denumerize output

    The native mode is composed only of the 3rd stage.

    Provides some common functionality for all or some Atomic Ops,
    as well as an interface for them to implement (call_native, call_hw_sim, etc.)

    Additionally, AtomicOp offers an API for statistics collection, but the layer has to be run eagerly for it.
    TODO: add keras' dynamic argument to force eager layer
    """

    # Debug tensors sampled in eager mode when debug_mode=True, enabling debug and fine-grained analysis tools.
    # NOTE(review): judging by the names only, they hold the raw inputs, the numerized inputs, the numerized
    # inputs after bit reduction, the numerized outputs before/after bit reduction and the denumerized outputs
    # (the five emulation stages listed in the class docstring) - confirm against the code that fills them.
    _inputs_sample: tf.Tensor
    _inputs_num_sample: tf.Tensor
    _inputs_num_lossy_sample: tf.Tensor
    _output_num_sample: tf.Tensor
    _output_num_lossy_sample: tf.Tensor
    _output_sample: tf.Tensor

    # Declared only; not assigned anywhere in this part of the file (presumably set by subclasses).
    groups: int

    def __init__(self, name, logger=None, fully_native=None, bit_exact=False, **kwargs):
        """
        Args:
            name: name of the op, forwarded to the BaseOp constructor.
            logger: optional logger, forwarded to the BaseOp constructor.
            fully_native: the native flow of an atomic op - skips the numerization part.
                ``None`` is treated as ``False``.
            bit_exact: initial value of the ``bit_exact`` flag.
            kwargs: remaining arguments forwarded to the BaseOp constructor
                (presumably keras.layers.Layer arguments such as ``trainable`` - verify against BaseOp).

        """
        default_fully_native = False if fully_native is None else fully_native
        # Must run first: the helpers below read self.full_name (BaseOp) and num_inputs / num_outputs.
        super().__init__(name, logger=logger, **kwargs)
        # Created as lossless always (every I/O lossy element starts as an IdentityElement);
        # use the enable_lossy() API to start simulating the lossy parts of the op.
        self.fully_native = default_fully_native
        self._input_shapes = None
        self._output_shapes = None
        # in case the op expects the input to be list in all cases (e.g. concat)
        self._force_input_list = False

        self.post_action = None
        # Controls sampling of the _*_sample debug tensors declared on the class.
        self.debug_mode = False

        self.internal_encoding_enabled = True
        self.internal_decoding_enabled = True
        self.quant_inputs_enabled = True
        self.weight_lossy_elements: BaseWeightLossyElements = BaseWeightLossyElements()
        self._ignore_hw_limitation_assertion = None
        # Identity I/O encodings (scale 1, zero point 0) and identity I/O lossy elements.
        self._init_io_encoding()
        self._init_io_lossy_elements()
        self._init_stats_collection()

        self._bit_exact = bit_exact

        self.encoding_const = False

        # NumPy dtypes are fixed; the TF dtypes are chosen by set_type_emulation() (float32/int32 by default).
        self.FLOAT_TYPE_NP = np.float32
        self.INT_TYPE_NP = np.int32
        self.set_type_emulation()

        # Load the HSim wrapper and keep a handle to its hsim object.
        hsim_wrapper = HSimWrapper()
        hsim_wrapper.load()
        self._hsim = hsim_wrapper.hsim

    def set_type_emulation(self, data_type=EmulationType.REGULAR):
        """
        Choose the TensorFlow dtypes used by the emulation.

        Args:
            data_type: ``EmulationType.REGULAR`` selects tf.float32 / tf.int32,
                ``EmulationType.DOUBLE`` selects tf.float64 / tf.int64.

        Raises:
            AccelerasInitializationError: for any other value.
        """
        if data_type == EmulationType.REGULAR:
            float_dtype, int_dtype = tf.float32, tf.int32
        elif data_type == EmulationType.DOUBLE:
            float_dtype, int_dtype = tf.float64, tf.int64
        else:
            raise AccelerasInitializationError("not a correct type")
        self.FLOAT_TYPE_TF, self.INT_TYPE_TF = float_dtype, int_dtype

    @property
    def bit_exact(self):
        """Whether the op runs in bit-exact mode (exported as FlowState.BIT_EXACT)."""
        flag = self._bit_exact
        return flag

    @bit_exact.setter
    def bit_exact(self, value):
        """Store the bit-exact flag exactly as given."""
        self._bit_exact = value

    # region Properties
    @property
    def num_of_channels(self):
        """Last dimension of the first input shape, or 1 while the input shapes are still unknown."""
        try:
            first_shape = self.input_shapes[0]
        except AttributeError:
            # input_shapes raises AttributeError until the op has been called
            return 1
        else:
            return first_shape[-1]

    @property
    def input_shapes_is_valid(self):
        """True once input shapes have been recorded (e.g. by compute_output_shape)."""
        recorded = self._input_shapes
        return recorded is not None

    @property
    def input_shapes(self):
        """
        The recorded input shapes.

        Raises:
            AttributeError: if the op has never been called, so no shapes were recorded.
        """
        shapes = self._input_shapes
        if shapes is not None:
            return shapes
        raise AttributeError(
            f"The layer {self.full_name} has never been called and thus has no defined input shapes."
        )

    @property
    def input_shape(self):
        """
        Shape of the op's single input.

        Raises:
            AttributeError: unless the op has exactly one input.
        """
        if self.num_inputs == 1:
            return self.input_shapes[0]
        raise AttributeError(
            f"The layer has multiple inputs. "
            f'Hence the notion of "input_shape" is '
            f"ill-defined for the layer - {self.full_name}",
        )

    @property
    def output_shapes_is_valid(self):
        """True once output shapes have been recorded (e.g. by compute_output_shape)."""
        recorded = self._output_shapes
        return recorded is not None

    @property
    def output_shapes(self):
        """
        The recorded output shapes.

        Raises:
            AttributeError: if the op has never been called, so no shapes were recorded.
        """
        shapes = self._output_shapes
        if shapes is not None:
            return shapes
        raise AttributeError(
            f"The layer {self.full_name} has never been called and thus has no defined output shapes."
        )

    @property
    def output_shape(self):
        """
        Shape of the op's single output.

        Raises:
            AttributeError: unless the op has exactly one output.
        """
        if self.num_outputs == 1:
            return self.output_shapes[0]
        raise AttributeError(
            f"The layer has multiple outputs. "
            f'Hence the notion of "output_shape" is '
            f"ill-defined for the layer - {self.full_name}",
        )

    @property
    def input_lossy_element(self):
        """
        Lossy element of the op's single input.

        Raises:
            AttributeError: if the op does not have exactly one input.
        """
        if self.num_inputs != 1:
            raise AttributeError(
                f"The layer has multiple inputs. "
                f'Hence the notion of "input_lossy_element" is '
                f"ill-defined for the layer - {self.full_name}",
            )
        return self.input_lossy_elements[0]

    @property
    def output_lossy_element(self):
        """
        Lossy element of the op's single output.

        Raises:
            AttributeError: unless the op has exactly one output.
        """
        if self.num_outputs == 1:
            return self.output_lossy_elements[0]
        raise AttributeError(
            f"The layer has multiple outputs. "
            f'Hence the notion of "output_lossy_element" is '
            f"ill-defined for the layer - {self.full_name}",
        )

    @property
    def input_scale(self):
        """
        Scale of the op's single input.

        Raises:
            AttributeError: if the op does not have exactly one input.
        """
        if self.num_inputs != 1:
            raise AttributeError(
                f"The layer has multiple inputs. "
                f'Hence the notion of "input_scale" is '
                f"ill-defined for the layer - {self.full_name}",
            )
        return self.input_scales[0]

    @input_scale.setter
    def input_scale(self, value):
        """
        Replace the scale of the op's single input.

        Raises:
            AttributeError: if the op does not have exactly one input.
        """
        if self.num_inputs != 1:
            raise AttributeError(
                f"The layer has multiple inputs. "
                f'Hence the notion of "input_scale" is '
                f"ill-defined for the layer - {self.full_name}",
            )
        self.input_scales[0] = value

    @property
    def output_scale(self):
        """Scale of the op's only output; AttributeError unless the op has exactly one output."""
        if self.num_outputs == 1:
            return self.output_scales[0]
        raise AttributeError(
            f"The layer has multiple outputs. "
            f'Hence the notion of "output_scale" is '
            f"ill-defined for the layer - {self.full_name}",
        )

    @output_scale.setter
    def output_scale(self, value):
        """Replace the scale of the op's only output; AttributeError unless it has exactly one output."""
        if self.num_outputs == 1:
            self.output_scales[0] = value
            return
        raise AttributeError(
            f"The layer has multiple outputs. "
            f'Hence the notion of "output_scale" is '
            f"ill-defined for the layer - {self.full_name}",
        )

    @property
    def input_zero_point(self):
        """
        Zero point of the op's single input.

        Raises:
            AttributeError: if the op does not have exactly one input.
        """
        if self.num_inputs != 1:
            raise AttributeError(
                f"The layer has multiple inputs. "
                f'Hence the notion of "input_zero_point" is '
                f"ill-defined for the layer - {self.full_name}",
            )
        return self.input_zero_points[0]

    @input_zero_point.setter
    def input_zero_point(self, value):
        """
        Replace the zero point of the op's single input.

        Raises:
            AttributeError: if the op does not have exactly one input.
        """
        if self.num_inputs != 1:
            raise AttributeError(
                f"The layer has multiple inputs. "
                f'Hence the notion of "input_zero_point" is '
                f"ill-defined for the layer - {self.full_name}",
            )
        self.input_zero_points[0] = value

    @property
    def output_zero_point(self):
        """Zero point of the op's only output; AttributeError unless the op has exactly one output."""
        if self.num_outputs == 1:
            return self.output_zero_points[0]
        raise AttributeError(
            f"The layer has multiple outputs. "
            f'Hence the notion of "output_zero_point" is '
            f"ill-defined for the layer - {self.full_name}",
        )

    @output_zero_point.setter
    def output_zero_point(self, value):
        """Replace the zero point of the op's only output; AttributeError unless it has exactly one output."""
        if self.num_outputs == 1:
            self.output_zero_points[0] = value
            return
        raise AttributeError(
            f"The layer has multiple outputs. "
            f'Hence the notion of "output_zero_point" is '
            f"ill-defined for the layer - {self.full_name}",
        )

    @property
    @abstractmethod
    def num_inputs(self) -> int:
        """Number of inputs of the op; every concrete op must provide it."""

    @property
    @abstractmethod
    def num_outputs(self) -> int:
        """Number of outputs of the op; every concrete op must provide it."""

    @property
    def bit_exact_supported(self) -> bool:
        """Whether the op implements a bit-exact flow; the base op does not."""
        # Ops with a bit-exact implementation override this property.
        return False

    @property
    def bit_exact_emulation_supported(self) -> bool:
        """Bit-exact emulation requires bit-exact support and a non fully-native op."""
        supported = self.bit_exact_supported
        if not supported:
            return supported
        return not self.fully_native

    # endregion

    @abstractmethod
    def create_weight_quant_element(self, **kwargs):
        """
        [Required] Implemented by every concrete op.

        Judging by the name, creates the quantization elements of the op's weights
        (presumably stored in ``weight_lossy_elements``) - confirm in subclasses.
        """

    @abstractmethod
    def call_native(self, inputs, training=None, **kwargs):
        """
        [Required]
        Native inference of the op: only the core operation, without numerization or
        bit reduction (see the native mode described in the class docstring).
        Implement in subclasses.
        """

    @abstractmethod
    def call_hw_sim(self, inputs, **kwargs):
        """
        [Required]
        Numeric core of the op, used for every run except fully-native and bit-exact ones.
        The caller surrounds it with (de)numerization and I/O bit reduction.

        Subclasses implement the actual operation as it looks in the "numerized space",
        using the final numeric weights (if any) computed when the encodings are inferred.
        """

    def call_bit_exact(self, inputs, **kwargs):
        """
        Bit-exact inference path.

        The base implementation casts every input to FLOAT_TYPE_TF, runs call_hw_sim, and
        applies the output lossy elements. Subclasses implement an operation that is bit exact to the HW.

        Returns:
            list: the outputs after the output lossy elements.
        """
        # TODO: subclasses override this when needed; bit_exact_supported is meant to become abstract.
        casted_inputs = []
        for numeric_input in inputs:
            casted_inputs.append(tf.cast(numeric_input, self.FLOAT_TYPE_TF))

        raw_outputs = self.call_hw_sim(casted_inputs, **kwargs)
        if not isinstance(raw_outputs, list):
            raw_outputs = [raw_outputs]
        return self._quantize(raw_outputs, self.output_lossy_elements, **kwargs)

    def enforce_encoding(self, *args, **kwargs):
        """
        [Optional] The "compile-time" part of the encodings computation; a no-op in the base op.

        Subclasses typically:
          1. compute all dependent encodings (scales & zero points) from the independent ones imposed from above;
          2. possibly also compute the final numeric weights (weights numerization and "bit reduction").
             Usually, however, step 2 is performed separately in a property method.
        Ideally implement it with fully differentiable TensorFlow so everything stays trainable;
        the method is meant to be invoked in call() to that end, to be part of the graph. See also:
        https://hailotech.atlassian.net/wiki/spaces/ML/pages/943259731/Vectorized+Scales
        https://hailotechcom-my.sharepoint.com/:w:/g/personal/alexf_hailo_ai/EfzOd6bw2B9BtUmzdqC9oa4B92CKMI-HKXXFgPPWjbJzwQ?e=2cTysg
        """

    def import_independent_params(self, params):
        """Load independent params (the counterpart of export_independent_params); ignored by the base op."""

    def is_differentiable(self) -> bool:
        """
        [Required] Tell whether gradients can be backpropagated through the op.

        Returns:
            bool: True in the base op; non-differentiable ops override this.
        """
        differentiable = True
        return differentiable

    def enable_lossy(self, **kwargs):
        """
        Enable every input, output and weight lossy element, in that order.

        Optionally override in subclasses whose desired default lossy elements are complex
        and can't simply be specified by the LossyElement class.
        """
        element_groups = (
            self.input_lossy_elements,
            self.output_lossy_elements,
            self.weight_lossy_elements.__dict__.values(),
        )
        for group in element_groups:
            for element in group:
                element.enable()

    def disable_lossy(self, **kwargs):
        """Disable every input, output and weight lossy element, in that order."""
        element_groups = (
            self.input_lossy_elements,
            self.output_lossy_elements,
            self.weight_lossy_elements.__dict__.values(),
        )
        for group in element_groups:
            for element in group:
                element.disable()

    def set_ignore_hw_limitation_assertion(self, ignore_hw_limitation_assertion):
        """Store the flag; ``None`` means "keep the current value"."""
        if ignore_hw_limitation_assertion is None:
            return
        self._ignore_hw_limitation_assertion = ignore_hw_limitation_assertion

    # region IO lossy & encoding

    def _init_io_encoding(self):
        """Reset the I/O numerization to identity: scale 1 and zero point 0 for every input and output."""
        self.input_scales = [np.float32(1.0) for _ in range(self.num_inputs)]
        self.input_zero_points = [np.float32(0.0) for _ in range(self.num_inputs)]
        self.output_scales = [np.float32(1.0) for _ in range(self.num_outputs)]
        self.output_zero_points = [np.float32(0.0) for _ in range(self.num_outputs)]

    def _init_io_lossy_elements(self):
        """Attach an IdentityElement (lossless) to every input and output of the op."""
        inputs_elements = []
        for idx in range(self.num_inputs):
            inputs_elements.append(IdentityElement(name=f"{self.full_name}/ie:input_{idx}"))
        outputs_elements = []
        for idx in range(self.num_outputs):
            outputs_elements.append(IdentityElement(name=f"{self.full_name}/ie:output_{idx}"))

        self.input_lossy_elements: List[BaseLossyElement] = inputs_elements
        self.output_lossy_elements: List[BaseLossyElement] = outputs_elements

    def set_input_lossy_element(self, element, index=0):
        """Replace the lossy element of input ``index``."""
        elements = self.input_lossy_elements
        elements[index] = element

    def set_output_lossy_element(self, element, index=0):
        """Replace the lossy element of output ``index``."""
        elements = self.output_lossy_elements
        elements[index] = element

    def _encode_inputs(self, inputs):
        """Numerize inputs: ``zero_point + x / scale`` per input, computed in FLOAT_TYPE_TF."""
        dtype = self.FLOAT_TYPE_TF
        encoded = []
        for value, scale, zero_point in zip(inputs, self.input_scales, self.input_zero_points):
            encoded.append(tf.cast(zero_point, dtype) + tf.cast(value, dtype) / tf.cast(scale, dtype))
        return encoded

    def _decode_output(self, outputs):
        """Denumerize outputs: ``(x - zero_point) * scale`` per output, computed in FLOAT_TYPE_TF."""
        dtype = self.FLOAT_TYPE_TF
        decoded = []
        for value, scale, zero_point in zip(outputs, self.output_scales, self.output_zero_points):
            decoded.append((tf.cast(value, dtype) - tf.cast(zero_point, dtype)) * tf.cast(scale, dtype))
        return decoded

    def _encode_outputs(self, outputs):
        """Numerize outputs: ``zero_point + x / scale`` per output, computed in FLOAT_TYPE_TF."""
        dtype = self.FLOAT_TYPE_TF
        encoded = []
        for value, scale, zero_point in zip(outputs, self.output_scales, self.output_zero_points):
            encoded.append(tf.cast(zero_point, dtype) + tf.cast(value, dtype) / tf.cast(scale, dtype))
        return encoded

    def _decode_inputs(self, inputs):
        """Denumerize inputs: ``(x - zero_point) * scale`` per input, computed in FLOAT_TYPE_TF."""
        dtype = self.FLOAT_TYPE_TF
        decoded = []
        for value, scale, zero_point in zip(inputs, self.input_scales, self.input_zero_points):
            decoded.append((tf.cast(value, dtype) - tf.cast(zero_point, dtype)) * tf.cast(scale, dtype))
        return decoded

    def _quantize(self, values, lossy_elements, training=False, **kwargs):
        """Apply each lossy element to its paired value (cast to FLOAT_TYPE_TF); extra kwargs are ignored."""
        quantized = []
        for value, element in zip(values, lossy_elements):
            quantized.append(element(tf.cast(value, self.FLOAT_TYPE_TF), training=training))
        return quantized

    def export_independent_params(self) -> dict:
        """
        Export the op's independent params: every tensor that should be either a variable
        or a constant numpy value. The base op has none.
        """
        independent_params = {}
        return independent_params

    def export_quant_weights(self) -> dict:
        """
        Export the 'dependent' params (mostly the weights with the scale applied): values used
        directly by the compiler / numeric simulation but not represented as variables.
        The base op has none.
        """
        quant_weights = {}
        return quant_weights

    def export_hw_params(self) -> dict:
        """Export only the params the hardware needs; the base op has none."""
        hw_params = {}
        return hw_params

    def export_flow_state(self) -> AtomicOpState:
        """
        Export the op's flow parameters, i.e. how the run is configured: the run status
        (fully native / bit exact / numeric), the enabled internal stages, and the flow
        state of every input, output and weight lossy element.
        """
        in_flow = self._gen_flow_dict(self.input_lossy_elements)
        out_flow = self._gen_flow_dict(self.output_lossy_elements)
        weight_flow = self._gen_flow_dict(self.weight_lossy_elements.__dict__.values())

        # fully_native takes precedence over bit_exact; otherwise the op runs numerically.
        if self.fully_native:
            flow_status = FlowState.FULLY_NATIVE
        elif self.bit_exact:
            flow_status = FlowState.BIT_EXACT
        else:
            flow_status = FlowState.NUMERIC

        return AtomicOpState(
            full_name=self.full_name,
            aops_class_type=self.__class__.__name__,
            status=flow_status,
            input_lossy_elements=in_flow,
            output_lossy_elements=out_flow,
            weight_lossy_elements=weight_flow,
            internal_encoding_enabled=self.internal_encoding_enabled,
            internal_decoding_enabled=self.internal_decoding_enabled,
            quant_inputs_enabled=self.quant_inputs_enabled,
        )

    def import_flow_state(self, atomic_state: AtomicOpState):
        """
        Restore the op's flow parameters from an AtomicOpState (the counterpart of export_flow_state),
        including the flow state of every input, output and weight lossy element.

        Raises:
            AccelerasInitializationError: if the state belongs to an op with a different full name.
        """
        if atomic_state.full_name != self.full_name:
            raise AccelerasInitializationError(
                f"while importing flow states, names didn't match. current {self.full_name} and attempted import {atomic_state.full_name}"
            )

        flow_status = FlowState(atomic_state.status)
        self.fully_native = flow_status == FlowState.FULLY_NATIVE
        self.bit_exact = flow_status == FlowState.BIT_EXACT

        own_and_imported = (
            (self.input_lossy_elements, atomic_state.input_lossy_elements),
            (self.output_lossy_elements, atomic_state.output_lossy_elements),
            (self.weight_lossy_elements.__dict__.values(), atomic_state.weight_lossy_elements),
        )

        self.internal_encoding_enabled = atomic_state.internal_encoding_enabled
        self.internal_decoding_enabled = atomic_state.internal_decoding_enabled
        self.quant_inputs_enabled = atomic_state.quant_inputs_enabled

        for own_elements, imported_states in own_and_imported:
            self._update_lossy_element_flow_state(own_elements, imported_states)

    @staticmethod
    def _gen_flow_dict(lossy_element_list: Iterable[BaseLossyElement]) -> Dict[str, Any]:
        """
        Map each lossy element's full name to its exported flow state.

        Any iterable is accepted; callers pass both lists and ``__dict__.values()`` views.
        """
        return {lossy.full_name: lossy.export_flow_state() for lossy in lossy_element_list}

    @staticmethod
    def _update_lossy_element_flow_state(
        lossy_element_list: Iterable[BaseLossyElement], lossy_dict_to_update_with: Dict[str, Any]
    ):
        """
        Import into each lossy element the flow state stored under its full name.

        Args:
            lossy_element_list: the elements to update (any iterable).
            lossy_dict_to_update_with: full name -> flow state, as produced by _gen_flow_dict.

        Raises:
            KeyError: if an element's full name is missing from ``lossy_dict_to_update_with``.
        """
        for lossy in lossy_element_list:
            lossy.import_flow_state(lossy_dict_to_update_with[lossy.full_name])

    @property
    def is_weights_lossy(self):
        """
        True unless some weight quant element (BaseQuantElement) is lossless.

        Weight elements that are not BaseQuantElement are ignored, so an op without
        weight quant elements reports True.
        """
        for element in self.weight_lossy_elements.__dict__.values():
            if isinstance(element, BaseQuantElement) and element.is_lossless:
                return False
        return True

    def export_quant(self, include_shared_weights=True) -> dict:
        """
        Export all the quantized information of the op, both dependent and independent tensors.
        Exports all data required for compilation, simulation, and reconstruction of the op.
        The IO encoding of every input and every output is always exported.

        This function adds the op name as a prefix to the dictionary keys. It will look as follows:
            <opname>/<key_name>
        The opname is the basename of the op e.g. 'conv_op'

        Args:
            include_shared_weights: when False, params reported by _remove_shared_params are dropped.
        """
        export_params = {}
        export_params.update(self.export_independent_params())
        export_params.update(self.export_quant_weights())
        for index in range(self.num_inputs):
            export_params.update(self.export_input_encoding(index))
        # Export every output, not only output 0, so that the result round-trips through
        # import_quant, which imports the encoding of each output index.
        for index in range(self.num_outputs):
            export_params.update(self.export_output_encoding(index))
        if not include_shared_weights:
            export_params = self._remove_shared_params(export_params)
        opname = os.path.basename(self.full_name)
        return {f"{opname}/{k}": v for k, v in export_params.items()}

    def _remove_shared_params(self, export_params: dict) -> dict:
        """Hook for ops with shared weights to drop them from export_quant; the base op keeps everything."""
        kept_params = export_params
        return kept_params

    def import_quant(self, params: dict):
        """
        Imports the quantization info and scales to the op, expects the value returned by export_quant.
        Imports independent params and the IO encoding of every input and output.

        The keys of the dict should look like '<opname>/<key_name>' (read export_quant for additional
        info). Only keys under this op's '<opname>/' prefix are used, with the prefix stripped.

        Args:
            params: flat dict of exported params; it may also hold entries of other ops.
        """
        op_prefix = f"{os.path.basename(self.full_name)}/"
        # Match on '<opname>/' (not just '<opname>') so ops whose names merely start with this
        # op's name (e.g. 'conv1' vs 'conv10') are not picked up, and deep-copy only this op's
        # entries instead of the whole (possibly model-wide) dict.
        op_params = copy.deepcopy(
            {key[len(op_prefix) :]: value for key, value in params.items() if key.startswith(op_prefix)}
        )
        self.import_independent_params(op_params)
        for index in range(self.num_inputs):
            self.import_input_encoding(op_params, index)
        for index in range(self.num_outputs):
            self.import_output_encoding(op_params, index)

    def calc_output_encoding_candidates(
        self,
        output_index: int,
        forced_range: Optional[Tuple[int, int]] = None,
        output_lossy_external: Optional[BaseQuantElement] = None,
        factor=None,
        translation_config=None,
        split_precision_zp=None,
    ):
        """
        Compute candidate encoding params (scale, zero point) for one output of the op.

        Args:
            output_index: index of the output whose encoding is computed.
            forced_range: optional (min, max) used instead of the collected output statistics.
                Both limits are multiplied by min(output_scale) / max(output_scale) of that output.
            output_lossy_external: optional lossy element (typically a BaseQuantElement) used
                instead of the op's own output lossy element.
            factor: optional multiplier applied to both range limits, e.g. a vector whose length
                is the number of output channels.
            translation_config: optional config. When it is given and its
                ``activation_symmetric_range`` is not ``allowed``, the range is symmetric iff that
                policy is ``enabled``. Otherwise the range is symmetric only if
                ``output_lossy_external`` is given.
            split_precision_zp: forwarded to update_scale.

        Returns:
            (output_scale, output_zp). A scalar scale is repeated to the number of channels
            (last dimension) of this output's shape.
        """
        output_lossy_element = (
            self.output_lossy_elements[output_index] if output_lossy_external is None else output_lossy_external
        )
        if translation_config is None or translation_config.activation_symmetric_range == LayerFeaturePolicy.allowed:
            activation_symmetric_range = output_lossy_external is not None
        else:
            activation_symmetric_range = translation_config.activation_symmetric_range == LayerFeaturePolicy.enabled

        output_stats = self.get_output_stats(output_index)
        if forced_range is not None:
            # Ratio between the smallest and the largest scale of this output.
            min_scale = np.min(self.output_scales[output_index])
            min_scale *= 1 / np.max(self.output_scales[output_index])
            scaled_forced_range = min_scale * np.array(forced_range, min_scale.dtype)
            forced_min = tf.ones_like(output_stats.min) * scaled_forced_range[0]
            forced_max = tf.ones_like(output_stats.max) * scaled_forced_range[1]
            limvals = (forced_min, forced_max)
        else:
            limvals = (output_stats.min, output_stats.max)

        if factor is not None:
            # Scale each limit separately. Multiplying the (min, max) tuple itself repeats the tuple
            # for an int factor and raises for a Python float factor.
            limvals = (np.asarray(limvals[0]) * factor, np.asarray(limvals[1]) * factor)

        output_scale, output_zp = update_scale(
            self.output_scales[output_index],
            limvals,
            output_lossy_element,
            self.full_name,
            self._logger,
            activation_symmetric_range,
            split_precision_zp=split_precision_zp,
        )
        if output_scale.shape == ():
            # Broadcast a scalar scale to the channels of *this* output (output_shape only
            # works for single-output ops).
            output_channels = self.output_shapes[output_index][-1]
            output_scale = np.repeat(output_scale, output_channels)

        return output_scale, output_zp

    def create_output_encoding_candidates(
        self,
        output_index,
        forced_range=None,
        output_lossy_external=None,
        translation_config=None,
        split_precision_zp=None,
        factor=None,
    ):
        """
        Create baseline encoding-params candidates from internally collected statistics;
        base class provides implementation for I/O (not weights).

        Does nothing when there are no output stats and no forced range. ``factor`` (new,
        keyword, default None) is forwarded to calc_output_encoding_candidates, mirroring
        create_input_encoding_candidates.

        Raises:
            AccelerasNumerizationError: if statistics collection has not completed.
        """
        if not self._has_output_stats() and forced_range is None:
            return
        if self.stats_collection_state != StatsState.COMPLETE:
            raise AccelerasNumerizationError("Can't translate param before stats collection")

        output_scale, output_zp = self.calc_output_encoding_candidates(
            output_index,
            forced_range=forced_range,
            output_lossy_external=output_lossy_external,
            factor=factor,
            translation_config=translation_config,
            split_precision_zp=split_precision_zp,
        )
        self.output_scales[output_index] = output_scale
        self.output_zero_points[output_index] = output_zp

    def import_output_encoding(self, params, output_index=0):
        """
        Import the scale and zero point of output ``output_index``.

        The output bits are not imported; they are only checked against the bits of the
        output lossy element.

        Raises:
            AccelerasPrematureQuantOperation: if that output's lossy element is not a BaseQuantElement.
            AccelerasImportParamConfigMismatch: if the imported bits differ from the element's bits.
        """
        lossy = self.output_lossy_elements[output_index]
        if not isinstance(lossy, BaseQuantElement):
            raise AccelerasPrematureQuantOperation("import_output_encoding", self.full_name)

        expected_bits = lossy.bits
        received_bits = params[f"output_bits:{output_index}"]
        if expected_bits != received_bits:
            raise AccelerasImportParamConfigMismatch(
                f"output_bits:{output_index}",
                expected_bits,
                received_bits,
                self.full_name,
            )

        self.output_scales[output_index] = params[f"output_scale:{output_index}"]
        self.output_zero_points[output_index] = params[f"output_zero_point:{output_index}"]

    def export_output_encoding(self, output_index=0):
        """
        Export the scale, zero point and bits of output ``output_index`` as float32 numpy arrays.

        Raises:
            AccelerasPrematureQuantOperation: if that output's lossy element is not a BaseQuantElement.
        """
        lossy = self.output_lossy_elements[output_index]
        if not isinstance(lossy, BaseQuantElement):
            raise AccelerasPrematureQuantOperation("export_output_encoding", self.full_name)

        scale = np.array(self.output_scales[output_index], np.float32)
        zero_point = np.array(self.output_zero_points[output_index], np.float32)
        bits = np.array(lossy.bits, np.float32)
        return {
            f"output_scale:{output_index}": scale,
            f"output_zero_point:{output_index}": zero_point,
            f"output_bits:{output_index}": bits,
        }

    def calc_input_encoding_candidates(
        self,
        input_index: int,
        input_lossy_external: Optional[BaseQuantElement] = None,
        factor=None,
        translation_config=None,
        split_precision_zp=None,
    ):
        """
        Compute candidate encoding params (scale, zero point) for one input of the op.

        Args:
            input_index: index of the input whose encoding is computed.
            input_lossy_external: optional lossy element (typically a BaseQuantElement) used
                instead of the op's own input lossy element.
            factor: optional multiplier applied to both range limits, e.g. a vector whose length
                is the number of input channels.
            translation_config: optional config. When it is given and its
                ``activation_symmetric_range`` is not ``allowed``, the range is symmetric iff that
                policy is ``enabled``. Otherwise the range is symmetric only if
                ``input_lossy_external`` is given.
            split_precision_zp: forwarded to update_scale.

        Returns:
            (inp_scale, inp_zp). A scalar scale is repeated to the number of channels
            (last dimension) of this input's shape.
        """
        input_lossy_element = (
            self.input_lossy_elements[input_index] if input_lossy_external is None else input_lossy_external
        )
        if translation_config is None or translation_config.activation_symmetric_range == LayerFeaturePolicy.allowed:
            activation_symmetric_range = input_lossy_external is not None
        else:
            activation_symmetric_range = translation_config.activation_symmetric_range == LayerFeaturePolicy.enabled
        input_stats = self.get_input_stats(input_index)
        input_scale = self.input_scales[input_index]

        limvals = (input_stats.min, input_stats.max)

        if factor is not None:
            # Scale each limit separately. Multiplying the (min, max) tuple itself repeats the tuple
            # for an int factor and raises for a Python float factor.
            limvals = (np.asarray(limvals[0]) * factor, np.asarray(limvals[1]) * factor)

        inp_scale, inp_zp = update_scale(
            input_scale,
            limvals,
            input_lossy_element,
            self.full_name,
            self._logger,
            activation_symmetric_range=activation_symmetric_range,
            split_precision_zp=split_precision_zp,
        )

        if inp_scale.shape == ():
            # Broadcast a scalar scale to the channels of this input.
            input_channels = self.input_shapes[input_index][-1]
            inp_scale = np.repeat(inp_scale, input_channels)

        return inp_scale, inp_zp

    def create_input_encoding_candidates(
        self, input_index, input_lossy_external=None, factor=None, translation_config=None, split_precision_zp=None
    ):
        """
        Create baseline encoding-params candidates for one input from the collected statistics
        and store them on the op; the base class provides this for I/O only (not weights).
        Does nothing when no input statistics exist.

        Raises:
            AccelerasNumerizationError: if statistics collection has not completed.
        """
        if self._has_inputs_stats():
            if self.stats_collection_state != StatsState.COMPLETE:
                raise AccelerasNumerizationError("Can't translate param before stats collection")
            scale, zero_point = self.calc_input_encoding_candidates(
                input_index,
                input_lossy_external=input_lossy_external,
                factor=factor,
                translation_config=translation_config,
                split_precision_zp=split_precision_zp,
            )
            self.input_zero_points[input_index] = zero_point
            self.input_scales[input_index] = scale

    def vectorize_scales(self):
        """
        Broadcast scalar I/O scales to per-channel vectors.

        Each scale whose shape is ``()`` is repeated along the last dimension of the matching
        input/output shape. Those shapes only exist after the model has been inferred.
        """

        def broadcast_scalars(scales, get_shapes):
            # get_shapes is called lazily, only when a scalar scale is actually found.
            for idx, scale in enumerate(scales):
                if scale.shape == ():
                    scales[idx] = np.repeat(scale, get_shapes()[idx][-1])

        broadcast_scalars(self.input_scales, lambda: self.input_shapes)
        broadcast_scalars(self.output_scales, lambda: self.output_shapes)

    @staticmethod
    def _get_dim(input_shape):
        """
        Return the nesting depth of ``input_shape``.

        A flat shape such as ``[None, 8, 8, 3]`` has depth 1, and a list of such shapes has depth 2. The depth is
        found by repeatedly descending into the first element until an integer (Python or numpy) or ``None`` is
        reached.

        Raises:
            AccelerasValueError: if no integer/None leaf is found within 100 levels.
        """
        depth = 0
        node = input_shape
        for _ in range(100):
            # np.integer covers shapes built from numpy arrays (e.g. np.int64 dims), which are not ``int`` instances
            if node is None or isinstance(node, (int, np.integer)):
                return depth
            depth += 1
            node = node[0]
        raise AccelerasValueError("input shape values should have ints")

    def compute_output_shape(self, input_shape):
        """
        Compute (and cache) the output shape(s) of the op for ``input_shape``.

        ``input_shape`` may be a single shape, a list of shapes, or a dict of shapes (its values are used). A nested
        shape passed to a single-input op is unwrapped to its first element. The normalized input shape(s) are
        cached in ``_input_shapes`` and the inferred output shape(s), as ``tf.TensorShape``, in ``_output_shapes``.

        Returns:
            Whatever ``_compute_output_shape`` returns: one shape for a single-output op, a list of shapes otherwise.

        Raises:
            InvalidInputShape: if a non-batch dimension of an inferred output shape is negative. For single-output
                ops this check only runs in eager mode.
        """
        if isinstance(input_shape, dict):
            input_shape = [v for v in input_shape.values()]
        dims = self._get_dim(input_shape)
        # a single-input op may receive its shape wrapped in an extra list
        if dims > 1 and self.num_inputs == 1:
            input_shape = input_shape[0]
        if self.num_inputs == 1:
            self._input_shapes = [input_shape]
        else:
            self._input_shapes = input_shape
        shapes = self._compute_output_shape(input_shape)
        if self.num_outputs > 1:
            # NOTE(review): the message says "Input shapes" although the checked shapes are output shapes, and
            # only negative dims are rejected (zero passes despite "must be positive").
            for shape in shapes:
                if tf.reduce_any(tf.less(shape[1:], 0)):
                    raise InvalidInputShape(
                        f"Input shapes {shape} must be positive in {self.full_name}", self.full_name
                    )
            self._output_shapes = [tf.TensorShape(shape) for shape in shapes]
        else:
            if tf.executing_eagerly():
                if tf.reduce_any(tf.less(shapes[1:], 0)):
                    raise InvalidInputShape(
                        f"Input shapes {shapes} must be positive in {self.full_name}", self.full_name
                    )
            else:
                # Do not run these lines in graph mode (no need to test?).
                # Note: Another option is to keep the if tf.reduce_any(...) condition and run the test
                #      in eager mode only for stats_collector.py::collect_stats(..., run_eagerly=True).
                pass
            self._output_shapes = [tf.TensorShape(shapes)]
        return shapes

    def _compute_output_shape(self, input_shape):
        """
        Infer the output shape(s) by running the op natively on random data.

        When the batch dimension is unknown (``None``), the probe uses a batch of 1 and the batch dimension of the
        result is reported as ``None`` again. For multi-input ops the batch size is read from the first input.

        Returns:
            A shape list for a single-output op, or a list of shape lists for a multi-output op.
        """
        multi_input = self.num_inputs > 1
        batch = input_shape[0][0] if multi_input else input_shape[0]
        unknown_batch = batch is None

        def _probe(shape):
            # an unknown batch is probed with a batch of 1
            return np.random.random([1, *shape[1:]]) if unknown_batch else np.random.random(shape)

        samples = [_probe(shape) for shape in input_shape] if multi_input else _probe(input_shape)
        result = self(samples, training=False, fully_native=True, skip_stats=True)

        if unknown_batch:
            # presumably the probe call recorded batch-1 input shapes; restore the caller's None-batch shapes
            self._input_shapes = input_shape if multi_input else [input_shape]

        if isinstance(result, list):
            out_shapes = [out.shape.as_list() for out in result]
            if unknown_batch:
                for out_shape in out_shapes:
                    out_shape[0] = None
            return out_shapes

        out_shape = result.shape.as_list()
        if unknown_batch:
            out_shape[0] = None
        return out_shape

    def import_input_encoding(self, params, input_index):
        """
        Load the scale and zero point of input ``input_index`` from ``params``.

        The input bit width is not loaded. Instead, the imported value is checked against the bits of the input's
        lossy element.

        Raises:
            AccelerasPrematureQuantOperation: if the input lossy element is not a quantization element yet.
            AccelerasImportParamConfigMismatch: if the imported bit width differs from the configured one.
        """
        lossy_element = self.input_lossy_elements[input_index]
        if not isinstance(lossy_element, BaseQuantElement):
            raise AccelerasPrematureQuantOperation("import_input_encoding", self.full_name)

        bits_key = f"input_bits:{input_index}"
        expected_bits = lossy_element.bits
        received_bits = params[bits_key]
        if expected_bits != received_bits:
            raise AccelerasImportParamConfigMismatch(bits_key, expected_bits, received_bits, self.full_name)

        self.input_scales[input_index] = params[f"input_scale:{input_index}"]
        self.input_zero_points[input_index] = params[f"input_zero_point:{input_index}"]

    def export_input_encoding(self, input_index):
        """
        Return the scale, zero point and bit width of input ``input_index`` as float32 numpy arrays.

        Keys are ``input_scale:<i>``, ``input_zero_point:<i>`` and ``input_bits:<i>``.

        Raises:
            AccelerasPrematureQuantOperation: if the input lossy element is not a quantization element yet.
        """
        lossy_element = self.input_lossy_elements[input_index]
        if not isinstance(lossy_element, BaseQuantElement):
            raise AccelerasPrematureQuantOperation("export_input_encoding", self.full_name)
        raw_values = {
            "input_scale": self.input_scales[input_index],
            "input_zero_point": self.input_zero_points[input_index],
            "input_bits": lossy_element.bits,
        }
        return {f"{name}:{input_index}": np.array(value, np.float32) for name, value in raw_values.items()}

    def enable_internal_encoding(self):
        """Turn on input encoding, output decoding and input quantization for this op."""
        for flag_name in ("internal_encoding_enabled", "internal_decoding_enabled", "quant_inputs_enabled"):
            setattr(self, flag_name, True)

    def disable_internal_encoding(
        self, encode_inputs=None, decode_outputs=None, quant_inputs=None, *, export_model_state=False
    ):
        """
        Set the internal encode / decode / input-quantization flags. By default all three are disabled.

        Args:
            encode_inputs, decode_outputs, quant_inputs: new flag values; ``None`` means ``False``.
            export_model_state: when True, no flag is modified. Instead, the flags that would change are returned.

        Returns:
            With ``export_model_state``: ``{full_name: {flag_name: new_value}}`` for the flags that differ from
            their current value, or ``None`` if none differ. Otherwise ``None``.
        """
        requested = {
            "internal_encoding_enabled": False if encode_inputs is None else encode_inputs,
            "internal_decoding_enabled": False if decode_outputs is None else decode_outputs,
            "quant_inputs_enabled": False if quant_inputs is None else quant_inputs,
        }
        state_diff = {name: value for name, value in requested.items() if getattr(self, name) != value}

        if export_model_state:
            return {self.full_name: state_diff} if state_diff else None

        for flag_name, value in requested.items():
            setattr(self, flag_name, value)
        return None

    # endregion

    # region Stats collection

    def _init_stats_collection(self):
        """Reset all statistics-collection state to a blank, not-collecting configuration."""
        self.stats_collection_state = StatsState.BLANK
        self.stats_managers = {}
        self.stats_cfg = ()
        self.output_limvals = {}
        self._collect_inputs = self._collect_output = False

    def start_stats_collection(self, stats_cfg: tuple = BasicTypeTuple, collect_inputs=True, collect_output=True):
        """
        Start (or extend) statistics collection for this op.

        Args:
            stats_cfg: statistic types to collect (user API to control the what & how of stats collection). They
                are merged with the types requested by earlier calls.
            collect_inputs: enable collection on the op inputs. Passing False does not turn off collection enabled
                by an earlier call.
            collect_output: same as ``collect_inputs``, for the op outputs.

        The statistics managers are rebuilt and reset, and the state becomes ``StatsState.RUNNING``.
        """
        # Merge in first-seen order. A set union yields a hash-dependent order, which for members hashed by name
        # (e.g. enum members) can differ between Python processes.
        self.stats_cfg = tuple(dict.fromkeys((*self.stats_cfg, *stats_cfg)))
        if collect_inputs:
            self._collect_inputs = True
        if collect_output:
            self._collect_output = True
        self._initialize_statistics_managers()
        self._reset_statistics_managers()
        self.stats_collection_state = StatsState.RUNNING

    def stop_stats_collection(self):
        """Mark collection as complete, finalize every statistics manager and stop collecting inputs/outputs."""
        self.stats_collection_state = StatsState.COMPLETE
        for manager in self.stats_managers.values():
            manager.finalize()
        self._collect_inputs = False
        self._collect_output = False

    def import_input_stats(self, stats, index):
        """
        Install externally computed statistics for input ``index`` and mark collection as complete.

        ``stats_cfg`` is replaced by the statistic types given by the keys of ``stats``.
        """
        self.stats_cfg = tuple(map(TypeStats, stats.keys()))
        self.stats_managers[f"inputs_{index}"] = ImportedStats(stats)
        self.stats_collection_state = StatsState.COMPLETE

    def import_output_stats(self, stats, index):
        """
        Install externally computed statistics for output ``index`` and mark collection as complete.

        ``stats_cfg`` is replaced by the statistic types given by the keys of ``stats``.
        """
        self.stats_cfg = tuple(map(TypeStats, stats.keys()))
        self.stats_managers[f"outputs_{index}"] = ImportedStats(stats)
        self.stats_collection_state = StatsState.COMPLETE

    def handle_update_stats(self, inputs, outputs):
        """
        Feed one batch of inputs/outputs to the statistics managers, but only while collection is running.

        All tensors are cast to ``STATS_TYPE_FLOAT`` first. Assumes eager execution.
        """
        if self.stats_collection_state != StatsState.RUNNING:
            return
        float_inputs = [tf.cast(tensor, STATS_TYPE_FLOAT) for tensor in inputs]
        float_outputs = [tf.cast(tensor, STATS_TYPE_FLOAT) for tensor in outputs]
        self._update_statistics_managers(float_inputs, float_outputs)

    def _initialize_statistics_managers(self):
        """
        Rebuild ``stats_managers`` from scratch.

        One manager is created per input (keys ``inputs_<i>``) when input collection is enabled, and one per output
        (keys ``outputs_<i>``) when output collection is enabled.
        """
        self.stats_managers = {}

        def _register(prefix, shapes):
            for idx, shape in enumerate(shapes):
                self._initialize_stats_manager(shape, f"{prefix}_{idx}")

        if self._collect_inputs:
            _register("inputs", self.input_shapes)
        if self._collect_output:
            _register("outputs", self.output_shapes)

    def _get_stats_axes(self, data_shape):
        """
        Return the axes that statistics accumulate over: every axis except the last one.

        For example, a rank-4 shape gives ``array([0, 1, 2])``.
        """
        all_axes = np.arange(len(data_shape))
        last_axis = all_axes.size - 1
        return np.delete(all_axes, last_axis)

    def _initialize_stats_manager(self, data_shape, key):
        """
        Create a ``StatsManager`` under ``key``, unless one already exists.

        Statistics accumulate over the axes returned by ``_get_stats_axes``. The metric length is the product of the
        remaining (non-accumulated) dimensions.
        """
        if key in self.stats_managers:
            return
        accumulate_axes = self._get_stats_axes(data_shape)
        kept_dims = np.delete(data_shape, accumulate_axes)
        self.stats_managers[key] = StatsManager(axis_to_accumulate=accumulate_axes, metric_length=np.prod(kept_dims))

    def _reset_statistics_managers(self):
        """
        Reset every statistics manager with the statistic types listed in ``stats_cfg``.

        Output managers also receive their ``output_limvals`` entry (if any) as the histogram range. Input managers
        never collect histograms. If no other type is requested, input managers are not reset at all.
        """
        output_prefix = "outputs_"
        for key, manager in self.stats_managers.items():
            if not key.startswith(output_prefix):
                # TODO for now we dont support histogram on inputs.
                input_cfg = tuple(stat for stat in self.stats_cfg if stat != TypeStats.HISTOGRAM)
                if input_cfg:
                    manager.reset(input_cfg)
                continue
            output_index = int(key[len(output_prefix) :])
            manager.reset(self.stats_cfg, hist_range=self.output_limvals.get(output_index))

    def _update_statistics_managers(self, inputs, outputs):
        """
        Forward each input/output tensor to its statistics manager.

        This applies only to the enabled collection directions. Each manager then updates the statistics listed in
        its configuration.
        """
        if self._collect_inputs:
            for idx, tensor in enumerate(inputs):
                self.stats_managers[f"inputs_{idx}"].update(tensor)
        if self._collect_output:
            for idx, tensor in enumerate(outputs):
                self.stats_managers[f"outputs_{idx}"].update(tensor)

    def _has_inputs_stats(self):
        """Return True if a statistics manager exists for every input of the op."""
        return all(f"inputs_{idx}" in self.stats_managers for idx in range(self.num_inputs))

    def _has_output_stats(self):
        """Return True if a statistics manager exists for every output of the op."""
        return all(f"outputs_{idx}" in self.stats_managers for idx in range(self.num_outputs))

    def set_output_limvals(self, output_index):
        """
        Cache the (min, max) range of output ``output_index`` in ``output_limvals``.

        When output statistics managers are reset, this cached range is passed as their histogram range.
        """
        # TODO: remove when we fully move to acceleras. assumes single atomic op with single output
        limvals = self.get_output_limvals(output_index)
        self.output_limvals[output_index] = limvals

    def get_input_stats(self, input_index) -> Statistics:
        """Return the statistics gathered by the manager of input ``input_index``."""
        manager = self.stats_managers[f"inputs_{input_index}"]
        return manager.get()

    def get_output_stats(self, output_index) -> Statistics:
        """Return the statistics gathered by the manager of output ``output_index``."""
        manager = self.stats_managers[f"outputs_{output_index}"]
        return manager.get()

    def get_input_limvals(self, input_index) -> tuple:
        """Return ``(min, max)`` of input ``input_index``, reduced over all entries of its min/max statistics."""
        stats = self.get_input_stats(input_index)
        lowest = np.min(stats.min)
        highest = np.max(stats.max)
        return lowest, highest

    def get_output_limvals(self, output_index) -> tuple:
        """Return ``(min, max)`` of output ``output_index``, reduced over all entries of its min/max statistics."""
        stats = self.get_output_stats(output_index)
        lowest = np.min(stats.min)
        highest = np.max(stats.max)
        return lowest, highest

    def get_group_input_limvals(self, input_index: int, groups: int = 1) -> Tuple[np.ndarray, np.ndarray]:
        """
        Return per-group ``(min, max)`` arrays for input ``input_index``.

        The collected min/max vectors are reshaped to ``(groups, -1)``, i.e. split into ``groups`` consecutive
        equal chunks, and each chunk is reduced.
        """
        stats = self.get_input_stats(input_index)
        grouped_min = stats.min.reshape(groups, -1)
        grouped_max = stats.max.reshape(groups, -1)
        return np.min(grouped_min, axis=1), np.max(grouped_max, axis=1)

    def get_group_output_limvals(self, output_index: int, groups: int = 1) -> Tuple[np.ndarray, np.ndarray]:
        """
        Return per-group ``(min, max)`` arrays for output ``output_index``.

        The collected min/max vectors are reshaped to ``(groups, -1)``, i.e. split into ``groups`` consecutive
        equal chunks, and each chunk is reduced.
        """
        stats = self.get_output_stats(output_index)
        grouped_min = stats.min.reshape(groups, -1)
        grouped_max = stats.max.reshape(groups, -1)
        return np.min(grouped_min, axis=1), np.max(grouped_max, axis=1)

    # endregion
    # region Create Hw Params
    def create_hw_params(self, *args, **kwargs):
        """
        Hook for the algorithmic part of building the op's encodings.

        It is meant to run once the basic, calibration-driven candidates already exist. The base implementation
        does nothing; for example, a pass-through op already has its final I/O scales and zero points.
        """
        return None

    def _io_numerization_partial_update(
        self,
        input_scales=(None,),
        input_zero_points=(None,),
        output_scale=None,
        output_zero_point=None,
    ):
        """
        Selectively overwrite I/O scales and zero points.

        Each entry of ``input_scales`` / ``input_zero_points`` replaces the value at the same input index, unless the
        entry is ``None``. Inputs without a matching entry keep their current value; this includes every input past
        the first when the default ``(None,)`` is used. Entries beyond the number of inputs are ignored. Only output
        index 0 is updated.

        TODO consider calling this method directly from subclasses instead of overloading
        """

        def _merge(current, updates):
            # Pad ``updates`` with None up to len(current): zip alone stops at the shorter sequence and would drop
            # the trailing inputs from ``current``.
            padded = list(updates)[: len(current)]
            padded += [None] * (len(current) - len(padded))
            return [old if new is None else new for old, new in zip(current, padded)]

        self.input_scales = _merge(self.input_scales, input_scales)
        self.input_zero_points = _merge(self.input_zero_points, input_zero_points)
        # used for legacy import, we should have only one output scale
        self.output_scales[0] = self.output_scales[0] if output_scale is None else output_scale
        self.output_zero_points[0] = self.output_zero_points[0] if output_zero_point is None else output_zero_point

    def build(self, input_shape):
        """
        Keras ``build`` hook: normalize ``input_shape``, delegate to ``_build`` and mark the op as built.

        A dict of shapes becomes the list of its values. A nested shape given to a single-input op is unwrapped to
        its first element. The normalized shape is stored in ``_build_input_shape``.
        """
        shape = list(input_shape.values()) if isinstance(input_shape, dict) else input_shape
        depth = self._get_dim(shape)
        if depth > 1 and self.num_inputs == 1:
            shape = shape[0]
        self._build(shape)
        self._build_input_shape = shape
        self.built = True

    def _build(self, input_shape):
        """Subclass hook invoked by ``build`` with the normalized input shape. The base implementation does nothing."""
        return None

    def call(
        self,
        inputs,
        fully_native=None,
        encoding_tensors: dict = None,
        skip_stats=False,
        training=False,
        **kwargs,
    ):
        """
        Keras ``call()`` implementation; dispatches to one of three emulation modes.

        - fully native (``fully_native``): ``call_native()`` on float inputs;
        - bit exact (``self.bit_exact``): ``call_bit_exact()`` on integer numeric inputs;
        - otherwise the numeric, hardware-like emulation via ``call_hw_sim()``.

        Args:
            inputs: a tensor or a list of tensors.
            fully_native: overrides ``self.fully_native`` when not None.
            encoding_tensors: when given, the op's encodings are first updated from it via ``update_encoding``.
            skip_stats: when True, statistics are not updated by this call.
            training: forwarded to the selected run function and to ``post_action``.

        Returns:
            A single output tensor, or a list of tensors for multi-output ops.
        """
        inputs = inputs if isinstance(inputs, list) else [inputs]
        self._input_shapes = [inp.shape for inp in inputs]

        fully_native = self.fully_native if fully_native is None else fully_native

        if encoding_tensors is not None:
            # Unlock the tracker only while updating, and relock it even if the update raises.
            self._tracker.locked = False
            try:
                self.update_encoding(encoding_tensors)
            finally:
                self._tracker.locked = True

        if fully_native:
            outputs = self._native_run(inputs, training=training, **kwargs)
        elif self.bit_exact:
            outputs = self._bit_exact_run(inputs, training=training, **kwargs)
        else:
            outputs = self._numeric_run(inputs, training=training, **kwargs)

        if not skip_stats:
            self.handle_update_stats(inputs, outputs)

        if self.post_action is not None:
            outputs = self.post_action(inputs, outputs, training=training)
        self._output_shapes = [out.shape for out in outputs]
        return outputs if self.num_outputs > 1 else outputs[0]

    def _to_numeric(self, inputs, encode_inputs=None, should_quant=None, training=False):
        """
        Move inputs from native space to numeric space.

        First, inputs are encoded when ``encode_inputs`` is set (default ``internal_encoding_enabled``). Then they
        pass through the input lossy elements when ``should_quant`` is set (default ``quant_inputs_enabled``). In
        debug mode the three stages are stored as numpy samples.
        """
        if encode_inputs is None:
            encode_inputs = self.internal_encoding_enabled
        if should_quant is None:
            should_quant = self.quant_inputs_enabled

        encoded = self._encode_inputs(inputs) if encode_inputs else inputs
        if should_quant:
            quantized = self._quantize(encoded, self.input_lossy_elements, training=training)
        else:
            quantized = encoded

        if self.debug_mode:
            self._inputs_sample = [tensor.numpy() for tensor in inputs]  # 1
            self._inputs_num_sample = [tensor.numpy() for tensor in encoded]  # 2
            self._inputs_num_lossy_sample = [tensor.numpy() for tensor in quantized]  # 3
        return quantized

    def _from_numeric(self, outputs_num, decode_outputs=None, should_quant=True, training=False):
        """
        Move outputs from numeric space back to native space.

        First, outputs pass through the output lossy elements when ``should_quant`` is set. Then they are decoded
        when ``decode_outputs`` is set (default ``internal_decoding_enabled``). In debug mode the three stages are
        stored.
        """
        if decode_outputs is None:
            decode_outputs = self.internal_decoding_enabled

        if should_quant:
            lossy = self._quantize(outputs_num, self.output_lossy_elements, training=training)
        else:
            lossy = outputs_num
        native = self._decode_output(lossy) if decode_outputs else lossy

        if self.debug_mode:
            self._output_num_sample = outputs_num  # 4
            self._output_num_lossy_sample = lossy  # 5
            self._output_sample = native  # 6
        return native

    # region call helper functions

    def _bit_exact_run(self, inputs, training, **kwargs):
        """
        Bit-exact emulation: results are meant to match the Hailo hardware output exactly, without a device.

        Inputs are moved to numeric space and cast to ``INT_TYPE_TF`` before ``call_bit_exact()``, which the op
        implements. The results are converted back without output quantization.
        """
        numeric_inputs = self._to_numeric(inputs, training=training)
        int_inputs = [tf.cast(tensor, self.INT_TYPE_TF) for tensor in numeric_inputs]
        raw_outputs = self.call_bit_exact(int_inputs, training=training, **kwargs)
        if not isinstance(raw_outputs, list):
            raw_outputs = [raw_outputs]
        # outputs are not quantized again in bit-exact mode
        return self._from_numeric(raw_outputs, should_quant=False, training=training)

    def _numeric_run(self, inputs, training, **kwargs):
        """
        Main "hardware-like" emulation (not bit exact).

        Wraps the op-specific ``call_hw_sim()`` with the common steps:
            - casting inputs to ``FLOAT_TYPE_TF``;
            - input numerization and input lossy elements (per the ``_to_numeric`` flags);
            - output lossy elements (always applied) followed by output denumerization.
        Lossy handling of weights, if any, is op-specific and implemented in the subclass.
        """
        float_inputs = [tf.cast(tensor, self.FLOAT_TYPE_TF) for tensor in inputs]
        numeric_inputs = self._to_numeric(float_inputs, training=training)
        hw_outputs = self.call_hw_sim(numeric_inputs, training=training, **kwargs)
        if not isinstance(hw_outputs, list):
            hw_outputs = [hw_outputs]
        # output lossy elements are always applied in this mode
        return self._from_numeric(hw_outputs, should_quant=True, training=training)

    def _native_run(self, inputs, training, **kwargs):
        """
        Native emulation: run ``call_native()`` on float inputs with no encoding or lossy steps.

        In debug mode every stage sample is set to the native inputs/outputs, so that they line up with the samples
        recorded by the numeric modes.
        """
        float_inputs = [tf.cast(tensor, self.FLOAT_TYPE_TF) for tensor in inputs]
        outputs = self.call_native(float_inputs, training=training, **kwargs)
        if not isinstance(outputs, list):
            outputs = [outputs]

        if self.debug_mode:
            native_inp = [tf.cast(v, self.FLOAT_TYPE_TF).numpy() for v in inputs]
            self._inputs_sample = self._inputs_num_sample = self._inputs_num_lossy_sample = native_inp  # 1, 2, 3
            self._output_num_sample = self._output_num_lossy_sample = self._output_sample = outputs  # 4, 5, 6

        return outputs

    # endregion

    # region Legacy export
    def export_output_params(self, output_scale=None):
        """
        Legacy export of the quantization parameters of output 0.

        Args:
            output_scale: overrides the stored output scale when given.

        Returns:
            dict with ``qp_out`` = (zero_point, scale) and ``limvals_out`` (from ``qp_to_limvals`` with the output
            bit width), both as float32.
        """
        scale = self.output_scales[0] if output_scale is None else output_scale
        qp_out = (self.output_zero_points[0], scale)
        limvals_out = qp_to_limvals(qp_out, self.output_lossy_elements[0].bits)
        return {"qp_out": np.float32(qp_out), "limvals_out": np.float32(limvals_out)}

    # endregion

    def input_scale_is_scalar(self, input_index=0):
        """
        Return True if the scale of input ``input_index`` is a scalar (rank-0) value.

        Args:
            input_index (int, optional): index of the input whose scale is checked. Defaults to 0.
        """
        scale_rank = len(tf.convert_to_tensor(self.input_scales[input_index]).shape)
        return scale_rank == 0

    def output_scale_is_scalar(self, input_index=0):
        """
        Return True if the scale of an output is a scalar (rank-0) value.

        Args:
            input_index (int, optional): index of the *output* whose scale is checked, despite the parameter name.
                Defaults to 0.
        """
        scale_rank = len(tf.convert_to_tensor(self.output_scales[input_index]).shape)
        return scale_rank == 0

    # region encoding flow

    def get_encoding_flow(self) -> EncodingFlowGraph:
        """
        Build and return an encoding flow graph holding this op's encodings and their constraints.

        When ``encoding_const`` is set, the constraints come from ``define_const_constraints``; otherwise they come
        from ``define_constraints``.
        """
        flow = EncodingFlowGraph()
        constraints = EncodingSubOp(flow)
        self.define_encodings(flow)
        add_constraints = self.define_const_constraints if self.encoding_const else self.define_constraints
        add_constraints(constraints)
        return flow

    def define_encodings(self, flow: EncodingFlowGraph):
        """
        Define the encoding nodes of the atomic op.

        For every input and output, two nodes are added, both initialized from the op's current values:
            - a scale node of shape ``(shape[-1],)``;
            - a zero-point node shaped like the current zero point, or ``()`` if it has no ``shape``.

        Encoding names should look like '{self.full_name}/<key_name>:<index>'. If an equivalent value exists in
        import_quant/export_quant functions, then key_name should match.

        e.g <layer_name>/conv_op/mac_shift:0 will be used to describe the pre_acc_shift of conv_op inside the
        layer called layer_name.

        Args:
            flow (EncodingFlowGraph): base encoding flow graph to add the atomic op's encodings to

        """

        def _add_tensor_encodings(kind, index, shape, scale, zero_point):
            scale_init = TensorInitializer(scale)
            zp_init = TensorInitializer(zero_point)
            # zero points without a ``shape`` attribute (e.g. plain Python numbers) get an empty shape
            zp_shape = zero_point.shape if hasattr(zero_point, "shape") else ()
            flow.add_encoding(
                f"{self.full_name}/{kind}_scale:{index}",
                EncodingType.Scale,
                scalar=False,
                shape=(shape[-1],),
                initializer=scale_init,
            )
            flow.add_encoding(
                f"{self.full_name}/{kind}_zero_point:{index}",
                EncodingType.ZeroPoint,
                scalar=False,
                shape=zp_shape,
                initializer=zp_init,
                quant=True,
                quant_min=tf.float32.min,
                quant_max=tf.float32.max,
            )

        for index, shape in enumerate(self.input_shapes):
            _add_tensor_encodings("input", index, shape, self.input_scales[index], self.input_zero_points[index])
        for index, shape in enumerate(self.output_shapes):
            _add_tensor_encodings("output", index, shape, self.output_scales[index], self.output_zero_points[index])

    def define_constraints(self, enc: EncodingSubOp):
        """
        Hook for declaring constraints between the op's encoding nodes. The base implementation adds none.

        For example, an op could tie its output encodings to its input encodings:

        .. code-block::

            enc.identity(f'{self.full_name}/input_scale:0', f'{self.full_name}/output_scale:0')
            enc.identity(f'{self.full_name}/input_zero_point:0', f'{self.full_name}/output_zero_point:0')

        Args:
            enc (EncodingSubOp): atomic constraints to define relations between the encoding nodes

        """
        return None

    def define_const_constraints(self, enc: EncodingSubOp):
        """
        Constrain every I/O scale and zero-point encoding to its current value.

        Args:
            enc (EncodingSubOp): atomic constraints to define relations between the encoding nodes

        """

        def _pin(kind, count, scales, zero_points):
            for idx in range(count):
                enc.identity(f"{self.full_name}/{kind}_scale:{idx}", scales[idx])
                enc.identity(f"{self.full_name}/{kind}_zero_point:{idx}", zero_points[idx])

        _pin("input", self.num_inputs, self.input_scales, self.input_zero_points)
        _pin("output", self.num_outputs, self.output_scales, self.output_zero_points)

    def update_encoding(self, encodings: Dict[str, Any]):
        """
        Overwrite the op's I/O scales and zero points from ``encodings``.

        Args:
            encodings (dict): mapping of '<full_op_name>/<key_name>:<index>' to encoding value. Every I/O scale and
                zero point of the op must be present (see define_encodings for the naming).

        """

        def _load(kind, count, scales, zero_points):
            for idx in range(count):
                scales[idx] = encodings[f"{self.full_name}/{kind}_scale:{idx}"]
                zero_points[idx] = encodings[f"{self.full_name}/{kind}_zero_point:{idx}"]

        _load("input", self.num_inputs, self.input_scales, self.input_zero_points)
        _load("output", self.num_outputs, self.output_scales, self.output_zero_points)

    # endregion

    # region debugging
    def _verify_data_dtype(self, data_1, bit_width, signed, name):
        """Debug-only check: in ``debug_mode`` delegate to ``verify_data_dtype``; otherwise do nothing."""
        if not self.debug_mode:
            return
        verify_data_dtype(data_1, bit_width, signed, name)

    # endregion

    def _shift_right(self, value, shift, signed=False):
        """
        Right-shift ``value`` by ``shift`` bits.

        With ``signed``, the magnitude is shifted and the sign re-applied, so negative values truncate toward zero.
        Otherwise ``tf.bitwise.right_shift`` is applied to ``value`` directly.
        """
        if not signed:
            return tf.bitwise.right_shift(value, shift)
        magnitude = tf.bitwise.right_shift(tf.abs(value), shift)
        return tf.sign(value) * magnitude

    def bankers_round_with_shift(self, value, shift, bankers_round=3, signed=False):
        """
        Right-shift ``value`` by ``shift`` bits, rounding half to even on the last shifted-out bits.

        The value is first shifted by ``shift - min(shift, bankers_round)`` bits, which truncates the low bits.
        The remaining ``min(shift, bankers_round)`` bits are then removed by ``_bankers_round_unsignd_shift``,
        which rounds half to even based only on those bits.

        Args:
            value: integer tensor to shift.
            shift: number of bits to shift right.
            bankers_round: maximal number of low bits that take part in the rounding decision.
            signed: if True, ``value`` is shifted directly (two's complement). Otherwise its magnitude is shifted and
                rounded, and the sign is re-applied.

        NOTE(review): here ``signed=True`` selects the direct (two's complement) path, while in ``_shift_right``
        ``signed=True`` selects the sign-magnitude path — confirm this asymmetry is intended.
        """
        bankers_round = tf.cast(tf.math.minimum(shift, bankers_round), value.dtype)  # never more bits than shifted
        sign_x = tf.sign(value)  # only used by the sign-magnitude (signed=False) path
        if signed:
            # for signed integer dtypes tf performs an arithmetic (flooring) right shift
            to_bankers = tf.bitwise.right_shift(value, shift - bankers_round)
            bankers_unsignd = self._bankers_round_unsignd_shift(to_bankers, bankers_round)
            return bankers_unsignd
        else:
            # shift/round the magnitude, then restore the sign
            to_bankers = tf.bitwise.right_shift(tf.abs(value), shift - bankers_round)
            bankers_unsignd = self._bankers_round_unsignd_shift(to_bankers, bankers_round)
            return sign_x * bankers_unsignd

    def _bankers_round_unsignd_shift(self, value, shift):
        """
        Right-shift ``value`` by ``shift`` bits, rounding the result half to even.

        The shifted-out bits (``value % 2**shift``) are compared with half of ``2**shift``. A larger remainder rounds
        up; an exactly-half remainder rounds up only when the truncated result is odd.

        NOTE(review): the remainder and the threshold are compared in float32, which is exact only while
        ``2**shift`` fits in float32's 24-bit mantissa — confirm shifts stay small.
        NOTE(review): despite the name, ``bankers_round_with_shift(signed=True)`` also passes negative values here.
        The right shift is then flooring, and ``%`` is presumably a floor-mod, so the rounding stays consistent —
        confirm.
        """
        shift = tf.cast(shift, value.dtype)
        floor_x = tf.bitwise.right_shift(value, shift)

        odd = floor_x % 2  # parity of the truncated result, used to break ties
        shift_numeric = 2**shift
        mod = value % shift_numeric  # the shifted-out low bits
        thr = tf.cast(shift_numeric, tf.float32) / 2  # exactly half of the rounding span
        mod = tf.cast(mod, thr.dtype)

        equal = tf.cast((mod == thr), odd.dtype)
        greater = tf.cast(tf.math.greater(mod, thr), odd.dtype)
        to_round = greater + equal * odd  # round up when above half, or at half with an odd result

        return floor_x + to_round

    @staticmethod
    def signed_shift_bankers_rounding(input_num, shift_amount, n_bankers_rounding_bits=3):
        """
        Compute ``input_num >> shift_amount`` rounded half to even, using ``n_bankers_rounding_bits`` guard bits.

        Only the first ``n_bankers_rounding_bits`` bits below the binary point take part in rounding. Any lower bits
        are truncated, so a value slightly above a tie is treated as an exact tie.

        Args:
            input_num: integer tensor to shift. For signed dtypes tf right shifts are arithmetic, so negative
                values are floored before rounding.
            shift_amount: number of bits to shift right; cast to ``input_num.dtype``.
            n_bankers_rounding_bits: number of guard bits kept for the rounding decision.

        Returns:
            The rounded, shifted tensor, with the dtype of ``input_num``.

        NOTE(review): ``input_num`` is first shifted left by ``n_bankers_rounding_bits``. This can overflow when
        ``input_num`` is close to the limits of its dtype — confirm the caller's value range.
        """
        shift_amount = tf.cast(shift_amount, input_num.dtype)

        extended = tf.bitwise.left_shift(input_num, n_bankers_rounding_bits)  # append zero guard bits below the LSB
        # the actual (arithmetic for signed dtypes) right shift
        shifted = tf.bitwise.right_shift(extended, shift_amount)
        resideu = tf.bitwise.bitwise_and(
            shifted,
            ((2**n_bankers_rounding_bits) - 1),
        )  # bitwise AND to extract the residue (guard) bits
        # get rid of the guard bits, which are already saved in the residue
        pre_rounding = tf.bitwise.right_shift(shifted, n_bankers_rounding_bits)
        # Flag high if value pre rounding is odd
        value_is_odd = tf.cast(pre_rounding & 0x1, tf.bool)
        # Flag high if residue is greater than half the rounding span
        residue_is_greater_than_half = resideu > (2 ** (n_bankers_rounding_bits - 1))
        # Flag high if residue is equal to half the rounding span (a tie)
        residue_is_equal_to_half = resideu == (2 ** (n_bankers_rounding_bits - 1))
        round_up_needed = tf.cast(
            tf.logical_or(residue_is_greater_than_half, tf.logical_and(residue_is_equal_to_half, value_is_odd)),
            input_num.dtype,
        )  # Flag high if round up needed
        post_rounding = pre_rounding + round_up_needed
        return post_rounding

    def hw_simulation_by_lossy_element(self, value, lossy_element, symmetric: bool = False):
        """
        Emulate the hardware range of ``lossy_element`` on ``value``: wrap around if configured, then saturate.

        Signed elements clip to ``[-2**(bits-1), 2**(bits-1) - 1]``; with ``symmetric`` the lower bound is raised
        by one. Unsigned elements clip to ``[0, 2**bits - 1]``.
        """
        value = self._wrap_if_needed(value, lossy_element)
        bits = lossy_element.bits
        if not lossy_element.signed:
            return tf.clip_by_value(value, 0, (2**bits) - 1)
        half_range = 2 ** (bits - 1)
        lower = -half_range + 1 if symmetric else -half_range
        return tf.clip_by_value(value, lower, half_range - 1)

    def _wrap_if_needed(self, value, lossy_element):
        """
        Wrap ``value`` modulo ``2**bits`` into the range of ``lossy_element``, like integer overflow.

        This only happens when the element has ``wraparound`` set; otherwise ``value`` is returned unchanged.
        """
        if lossy_element.wraparound:
            n_bits = lossy_element.bits
            offset = -(2 ** (n_bits - 1)) if lossy_element.signed else 0
            shifted = value - offset
            # number of whole 2**n_bits periods to remove (an arithmetic shift floors for signed integer dtypes)
            periods = tf.bitwise.right_shift(shifted, n_bits)
            return shifted - periods * 2**n_bits + offset
        return value
