import time
from dataclasses import dataclass
from functools import partial
from typing import Tuple, Union

import numpy as np
import pwlf
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops._misc_internals import hailo_reciprocal
from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp, BaseWeightLossyElements
from hailo_model_optimization.acceleras.encoding.encoding_data import EncodingType
from hailo_model_optimization.acceleras.encoding.encoding_layer import TensorInitializer
from hailo_model_optimization.acceleras.lossy_elements.base_lossy_element import BaseLossyElement
from hailo_model_optimization.acceleras.lossy_elements.identity_element import IdentityElement
from hailo_model_optimization.acceleras.lossy_elements.quant_element import (
    BiasedDeltaQuantElement,
    QuantElement,
    RoundUpQuantElement,
)
from hailo_model_optimization.acceleras.statistics.statistics_base import TypeStats, update_stats
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    APU_CLIP_BITS_1,
    APU_CLIP_BITS_1_D,
    APU_CLIP_BITS_2,
    APU_CLIP_BITS_2_D,
    APU_EXP_BIAS_BITS,
    APU_EXP_BIAS_BITS_D,
    APU_EXP_BITS,
    APU_FINAL_SHIFT,
    APU_FINAL_SHIFT_D,
    APU_MANTISSA_BITS,
    APU_OFFSET_BITS,
    APU_OFFSET_BITS_D,
    DEFAULT_X_POINTS_MAX_VALUE,
    DUMMY_EXPONENT,
    POST_SHIFT_1_ROUNDING,
    ActivationFitPolicy,
    ActivationType,
    IgnoreHwLimitationAssertionPolicy,
    OptimizationTarget,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import (
    AccelerasImplementationError,
    AccelerasInitializationError,
    AccelerasNegativeSlopesError,
    AccelerasValueError,
)
from hailo_model_optimization.acceleras.utils.activation_definitions import (
    ACTIVATION_FITTING_PIECES,
    ACTIVATIONS_FITTING_SUPPORTED,
    ACTIVATIONS_TO_FIT,
    MUST_FIT_ACTIVATIONS,
    NATIVE_ONLY_ACTIVATION_TYPES,
    QUANTIZATION_SUPPORTED_ACTIVATION_TYPES,
    get_num_of_pieces,
)
from hailo_model_optimization.acceleras.utils.flow_state_utils import AtomicOpState
from hailo_model_optimization.acceleras.utils.quantization_group_utils import (
    _add_first_and_end_points,
    get_base_group_size,
    get_quantization_groups_info,
    get_split_size,
    split_to_quantize_groups,
)


@dataclass
class ActivationWeightsLossy(BaseWeightLossyElements):
    """Bundle of the lossy elements attached to the activation op's weight-like parameters.

    ``ActivationOp.__init__`` fills every slot with an ``IdentityElement``;
    ``ActivationOp.create_weight_quant_element`` replaces them with quantizing elements.
    """

    # mantissa element (a QuantElement once quantized; signedness depends on the optimization target)
    mantissa: BaseLossyElement
    # round-up variant of the mantissa element (an unsigned RoundUpQuantElement once quantized)
    mantissa_round_up: BaseLossyElement
    # offset element (a signed QuantElement once quantized)
    offset: BaseLossyElement
    # first clip element (a signed QuantElement once quantized)
    clip1: BaseLossyElement
    # second clip element (once quantized, its signedness follows the op's output lossy element)
    clip2: BaseLossyElement
    # thresholds element (once quantized, uses the bit width of the op's first input lossy element)
    thresholds: BaseLossyElement


class ActivationOp(BaseAtomicOp):
    """
    Represents Hailo APU (activation processing unit) -
        receiving L1/L2 ("accumulator") inputs, applying a scalar piecewise-linear function,
        and returning a L3 output which normally is also the output of the enclosing Layer.

       For a list of currently supported activation types, see the enum:
       hailo_model_optimization.acceleras.utils.acceleras_definitions.ActivationType

       NOTE - for this op there's a difference between **fully_native mode**,
            in which the original full-precision function is applied,
            and a **lossless mode**, in which piecewise-linear approximation is applied,
            (but no bits are thrown for either constants/inputs/intermediates).

    Attributes:
        remove_offsets: if True, the offsets are removed from the activation function,
            based on the calibration values.

    """

    weight_lossy_elements: ActivationWeightsLossy
    REMOVE_OFFSET_FACTOR = 1.3
    num_inputs = 1
    num_outputs = 1
    remove_offsets = False

    def __init__(
        self,
        name: str,
        activation: Union[ActivationType, str, callable],
        logger=None,
        fully_native=True,
        bit_exact=None,
        **kwargs,
    ):
        """Create the activation op with pass-through (identity) weight lossy elements.

        Args:
            name: op name, forwarded to the base op.
            activation: activation identifier (enum member, string or callable), resolved
                through ``create_act_name_and_func``.
            logger: optional logger, forwarded to the base op.
            fully_native: forwarded to the base op.
            bit_exact: forwarded to the base op.
            **kwargs: extra keyword arguments forwarded to the base op.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, bit_exact=bit_exact, **kwargs)

        # Every weight slot starts as an identity element named after the slot.
        identity_slots = ("mantissa", "mantissa_round_up", "offset", "clip1", "clip2", "thresholds")
        self.weight_lossy_elements = ActivationWeightsLossy(
            **{slot: IdentityElement(name=f"{self.full_name}/ie:{slot}") for slot in identity_slots}
        )

        self.act_native_params = {}
        self.act_numeric_params = {}
        # Set to True for dense layers, where the quantization group shapes must be validated.
        self.validate_shapes = False
        self._clip_range = None

        # Placeholder HW parameters used while there is no configuration yet.
        self.apu_final_shift = 0
        # Data shift; becomes non-zero only for specific APU io modes
        # (see _get_hw_consts_from_io_apu_bits).
        self.shift_data = 0
        self.apu_exp_bias = DUMMY_EXPONENT

        self._fit_policy = ActivationFitPolicy.allowed
        resolved_name, resolved_func = self.create_act_name_and_func(activation)
        self.act_name = resolved_name
        self.act_func = resolved_func

        # Empty piecewise-linear description by default.
        no_pieces = tuple()
        self.thresholds = no_pieces
        self.offsets = no_pieces
        self.slopes = no_pieces

        self.output_factor_by_group = np.ones(1, dtype=self.FLOAT_TYPE_NP)
        self.update_mantissa_exponent_decomposition()

    def get_config(self):
        """Return the parent config extended with this op's activation state.

        Returns:
            The dict produced by the parent ``get_config``, updated in place with the
            activation name, mode flags, fit policy, native/numeric params, HW constants,
            piecewise-linear description and the per-group decomposition arrays.
        """
        config = super().get_config()

        activation_state = {
            # identity and mode
            "name": self.full_name,
            "activation": self.act_name.value,
            "fully_native": self.fully_native,
            "bit_exact": self.bit_exact,
            "fit_policy": self._fit_policy.value,
            "remove_offsets": self.remove_offsets,
            # activation parameters
            "act_native_params": self.act_native_params,
            "act_numeric_params": self.act_numeric_params,
            "validate_shapes": self.validate_shapes,
            "clip_range": self._clip_range,
            # HW constants
            "apu_final_shift": self.apu_final_shift,
            "shift_data": self.shift_data,
            "apu_exp_bias": self.apu_exp_bias,
            # piecewise-linear description
            "thresholds": self.thresholds,
            "offsets": self.offsets,
            "slopes": self.slopes,
            # per-group factors and decomposition
            "output_factor_by_group": self.output_factor_by_group,
            "exponent_factors_by_group": self.exponent_factors_by_group,
            "mantissas_candidate_by_group": self.mantissas_candidate_by_group,
            "exponents_by_group": self.exponents_by_group,
        }
        config.update(activation_state)
        return config

    @classmethod
    def from_config(cls, config):
        valid_kwargs = {
            "name": config.pop("name"),
            "activation": config.pop("activation"),
            "fully_native": config.pop("fully_native"),
            "bit_exact": config.pop("bit_exact"),
        }
        instance = cls(**valid_kwargs)

        for key, value in config.items():
            if key in cls.__dict__:
                setattr(instance, key, value)

        return instance

    # region Properties

    @property
    def homogeneous(self):
        """Whether the activation is one of LINEAR, RELU or LEAKY."""
        homogeneous_types = {ActivationType.LINEAR, ActivationType.RELU, ActivationType.LEAKY}
        return self.act_name in homogeneous_types

    @property
    def quantization_groups_num(self):
        return self.output_factor_by_group.shape[0]

    @property
    def final_shift_factor(self):
        return 2**-self.apu_final_shift

    # endregion

    def is_differentiable(self) -> bool:
        """Return False for the BIASED_DELTA and DELTA activations, True for any other."""
        non_differentiable = [ActivationType.BIASED_DELTA, ActivationType.DELTA]
        return self.act_name not in non_differentiable

    def is_negative_input(self) -> bool:
        """Check the input-range condition used for the inverse-type activations.

        Only MINUS_INV_POS, or INV_POS / INV_SQRT with a negative ``inverse_act_factor``,
        can yield True. For those, the input statistics of input 0 are read and True is
        returned unless ``max(stats.max - stats.min)`` is smaller than ``min(-stats.min)``.

        NOTE(review): the method name suggests "the input holds only negative values";
        confirm that the comparison below expresses that intent.
        """
        if self.act_name in [ActivationType.MINUS_INV_POS]:
            range_check_needed = True
        else:
            range_check_needed = (
                self.act_name in [ActivationType.INV_POS, ActivationType.INV_SQRT]
                and self.act_numeric_params["inverse_act_factor"] < 0
            )
        if not range_check_needed:
            return False

        input_stats = self.get_input_stats(0)
        effective_range = np.max(input_stats.max - input_stats.min)
        wasted_range = np.min(-1 * input_stats.min)
        # Kept as a negated "<" (not rewritten as ">=") so NaN statistics behave the same.
        if not (effective_range < wasted_range):
            return True
        return False

    def create_weight_quant_element(self, optimization_target, **kwargs):
        """Replace the weight lossy elements with quantizing ones sized for the current APU io mode.

        Side effects: updates ``apu_exp_bias``, ``apu_final_shift`` and ``shift_data`` from
        the HW constants of the io mode, assigns a new ``ActivationWeightsLossy`` to
        ``weight_lossy_elements`` and switches ``INT_TYPE_TF`` to ``tf.int64`` when the
        input side of the io mode is 32 bits.

        Args:
            optimization_target: the mantissa is signed (with one extra bit) for any target
                that is neither ``None`` nor ``OptimizationTarget.SAGE``; unsigned otherwise.
            **kwargs: accepted but not used here.
        """
        io_mode = self._get_apu_io_mode()
        hw_consts = self._get_hw_consts_from_io_apu_bits(io_mode)
        clip_bits1, clip_bits2, offset_bits, exp_bias, final_shift, data_shift = hw_consts

        # Record the exponent bias, the final shift and the data shift of this io mode.
        self.apu_exp_bias = exp_bias
        self.apu_final_shift = final_shift
        self.shift_data = data_shift

        # Thresholds use the bit width of the first input lossy element.
        threshold_bits = self.input_lossy_elements[0].bits
        delta_like = [ActivationType.BIASED_DELTA, ActivationType.DELTA, ActivationType.GREATER]
        thresholds_cls = BiasedDeltaQuantElement if self.act_name in delta_like else QuantElement
        thresholds_elem = thresholds_cls(
            signed=True,
            bits=threshold_bits,
            wraparound=False,
            name=f"{self.full_name}/qe:thresholds",
        )

        if optimization_target != OptimizationTarget.SAGE and optimization_target is not None:
            signed_mantissa = True
        else:
            signed_mantissa = False
        # A signed mantissa needs one extra bit for the sign.
        mantissa_bits = (APU_MANTISSA_BITS + 1) if signed_mantissa else APU_MANTISSA_BITS

        mantissa_elem = QuantElement(
            signed=signed_mantissa,
            bits=mantissa_bits,
            wraparound=False,
            name=f"{self.full_name}/qe:mantissa",
        )
        round_up_elem = RoundUpQuantElement(
            signed=False,
            bits=APU_MANTISSA_BITS,
            wraparound=False,
            name=f"{self.full_name}/qe:mantissa_round_up",
        )
        offset_elem = QuantElement(
            signed=True, bits=offset_bits, wraparound=False, name=f"{self.full_name}/qe:offset"
        )
        clip1_elem = QuantElement(
            signed=True, bits=clip_bits1, wraparound=False, name=f"{self.full_name}/qe:clip1"
        )
        # The second clip follows the signedness of the op's output.
        clip2_elem = QuantElement(
            signed=self.output_lossy_element.signed,
            bits=clip_bits2,
            wraparound=False,
            name=f"{self.full_name}/qe:clip2",
        )

        self.weight_lossy_elements = ActivationWeightsLossy(
            mantissa=mantissa_elem,
            mantissa_round_up=round_up_elem,
            offset=offset_elem,
            clip1=clip1_elem,
            clip2=clip2_elem,
            thresholds=thresholds_elem,
        )

        input_bits = io_mode[0]
        if input_bits == 32:
            self.INT_TYPE_TF = tf.int64

    @staticmethod
    def _get_hw_consts_from_io_apu_bits(apu_io_mode):
        """Pick the APU HW constants for an ``(input_bits, output_bits)`` io mode.

        Modes handled explicitly:
        (16, 8) - regular 8-bit weights mode
        (32, 15) - regular 16-bit (15) weights mode
        (32, 8) - transition from 16 bits to 8 bits
        (16, 15) - transition from 8 bits to 16 bits (15)
        (20, 8) - gets its own final shift and data shift below

        Returns:
            Tuple ``(clip_bits1, clip_bits2, offset_bits, apu_exp_bias, apu_final_shift, shift_data)``.
        """
        # Only the regular 8-bit mode uses the non "_D" bit widths.
        if apu_io_mode == (16, 8):
            offset_bits = APU_OFFSET_BITS
            clip_bits1 = APU_CLIP_BITS_1
            clip_bits2 = APU_CLIP_BITS_2
            apu_exp_bias = APU_EXP_BIAS_BITS
        else:
            offset_bits = APU_OFFSET_BITS_D
            clip_bits1 = APU_CLIP_BITS_1_D
            clip_bits2 = APU_CLIP_BITS_2_D
            apu_exp_bias = APU_EXP_BIAS_BITS_D

        if apu_io_mode in ((32, 8), (20, 8)):
            apu_final_shift = APU_FINAL_SHIFT
        else:
            apu_final_shift = APU_FINAL_SHIFT_D

        if apu_io_mode == (20, 8):
            shift_data = 12
        elif apu_io_mode == (16, 15):
            shift_data = 8
        else:
            shift_data = 0

        return clip_bits1, clip_bits2, offset_bits, apu_exp_bias, apu_final_shift, shift_data

    def _get_apu_io_mode(self):
        output_bits = self.output_lossy_element.bits
        input_bits = self.input_lossy_element.bits
        return input_bits, output_bits

    @property
    def base_group_size(self):
        # TODO take the #groups from config, at least to validate?
        if self.quantization_groups_num == self.num_of_channels:
            return 1
        return get_base_group_size(self.quantization_groups_num, self.num_of_channels, self.validate_shapes)

    def get_slopes_count(self):
        """Return the number of linear pieces used for this activation.

        ``ACTIVATION_FITTING_PIECES`` when fitting is applied, otherwise the result of
        ``get_num_of_pieces`` for the activation and its native/numeric params.
        """
        if not self.apply_fitting:
            return get_num_of_pieces(self.act_name, self.act_native_params, self.act_numeric_params)
        return ACTIVATION_FITTING_PIECES

    def import_weights(self, layer_params, **kwargs):
        """Load the activation-specific parameters of this op from ``layer_params``.

        Both parameter dicts are reset first. Depending on the activation type, the relevant
        entries of ``layer_params`` are stored in ``act_native_params`` and/or
        ``act_numeric_params`` (as separate arrays), with defaults for the entries that are
        optional. Finally the activation name and function are resolved again through
        ``create_act_name_and_func``.

        Args:
            layer_params: mapping-like object holding the layer parameters.
            **kwargs: accepted but not used here.

        Raises:
            AccelerasInitializationError: for PWL parameters that are not 1D, whose lengths
                do not match, or that describe more than 9 pieces.
        """
        self.act_native_params = {}
        self.act_numeric_params = {}
        native = self.act_native_params
        numeric = self.act_numeric_params

        def store(native_key, numeric_key, raw):
            # One np.array call per dict, so the two entries never share the same array.
            native[native_key] = np.array(raw)
            numeric[numeric_key] = np.array(raw)

        def optional(key, default):
            # Value of an optional entry, or its default when the entry is absent.
            return layer_params[key] if key in layer_params.keys() else default

        act = self.act_name
        if act == ActivationType.LEAKY:
            store("alpha", "leaky_alpha", layer_params["leaky_alpha"])
        elif act == ActivationType.THRESHOLD:
            # Only the native side is populated for THRESHOLD.
            native["theta"] = np.array(layer_params["activation_threshold"])
        elif act == ActivationType.LESS:
            store("y", "activation_less_values", optional("activation_less_values", 0))
        elif act == ActivationType.GREATER:
            store("y", "activation_greater_values", optional("activation_greater_values", 0))
        elif act == ActivationType.PRELU:
            store("prelu_slope", "prelu_slope", layer_params["prelu_slope"])
        elif act == ActivationType.BIASED_DELTA:
            store("activation_delta_bias", "activation_delta_bias", optional("activation_delta_bias", -1))
        elif act == ActivationType.SWISH:
            store("swish_beta", "swish_beta", optional("swish_beta", 1))
        elif act == ActivationType.HARDSIGMOID:
            for key in ("hardsigmoid_alpha", "hardsigmoid_beta"):
                store(key, key, layer_params[key])
        elif act == ActivationType.CLIP:
            # CLIP bounds are squeezed rather than wrapped with np.array.
            for key in ("clip_min", "clip_max"):
                native[key] = np.squeeze(layer_params[key])
                numeric[key] = np.squeeze(layer_params[key])
        elif act in [ActivationType.INV_POS, ActivationType.INV_SQRT]:
            store("inverse_act_factor", "inverse_act_factor", layer_params.get("inverse_act_factor", 1))
        elif act == ActivationType.POW:
            store("pow_exponent", "pow_exponent", layer_params["pow_exponent"])
        elif act == ActivationType.PWL:
            thresholds = np.array(layer_params["thresholds"])
            offsets = np.array(layer_params["offsets"])
            slopes = np.array(layer_params["slopes"])

            if any(len(arr.shape) != 1 for arr in (thresholds, offsets, slopes)):
                raise AccelerasInitializationError(
                    f"Invalid PWL activation parameters, expected 1D arrays, got {thresholds.shape}, {offsets.shape}, {slopes.shape}"
                )
            # N thresholds delimit N + 1 pieces, each with its own offset and slope.
            expected_len = thresholds.shape[0] + 1
            if offsets.shape[0] != expected_len or slopes.shape[0] != expected_len:
                raise AccelerasInitializationError(
                    f"Invalid PWL activation parameters, expected thresholds to be 1 element shorter than offsets and slopes, got {thresholds.shape[0]}, {offsets.shape[0]}, {slopes.shape[0]}"
                )
            if thresholds.shape[0] > 8:
                raise AccelerasInitializationError(
                    f"Invalid PWL activation parameters, expected to have maximum of 9 pieces, got {thresholds.shape[0] + 1}"
                )

            for key, values in (("thresholds", thresholds), ("offsets", offsets), ("slopes", slopes)):
                native[key] = values
            for key, values in (("thresholds", thresholds), ("offsets", offsets), ("slopes", slopes)):
                numeric[key] = values.copy()
        elif act in [ActivationType.EXP_DECOMPOSE, ActivationType.SHIFT]:
            if "mask" in layer_params.keys():
                store("mask", "mask", layer_params["mask"])

        self.act_name, self.act_func = self.create_act_name_and_func(self.act_name)

    def export_weights(self):
        """Export the activation parameters of the layer as a dict of weights.

        Returns:
            dict mapping the exported parameter names of the current activation type to
            their values in ``act_native_params``; an empty dict for activation types
            without exported parameters.
        """
        act = self.act_name

        # (activation types, {exported key: key inside act_native_params}), checked in order.
        export_table = (
            ((ActivationType.LEAKY,), {"leaky_alpha": "alpha"}),
            ((ActivationType.THRESHOLD,), {"activation_threshold": "theta"}),
            ((ActivationType.LESS,), {"activation_less_values": "y"}),
            ((ActivationType.GREATER,), {"activation_greater_values": "y"}),
            ((ActivationType.PRELU,), {"prelu_slope": "prelu_slope"}),
            ((ActivationType.BIASED_DELTA,), {"activation_delta_bias": "activation_delta_bias"}),
            ((ActivationType.SWISH,), {"swish_beta": "swish_beta"}),
            (
                (ActivationType.HARDSIGMOID,),
                {"hardsigmoid_alpha": "hardsigmoid_alpha", "hardsigmoid_beta": "hardsigmoid_beta"},
            ),
            ((ActivationType.CLIP,), {"clip_min": "clip_min", "clip_max": "clip_max"}),
            ((ActivationType.INV_POS, ActivationType.INV_SQRT), {"inverse_act_factor": "inverse_act_factor"}),
            ((ActivationType.POW,), {"pow_exponent": "pow_exponent"}),
            ((ActivationType.PWL,), {"thresholds": "thresholds", "offsets": "offsets", "slopes": "slopes"}),
        )
        for act_types, key_map in export_table:
            if act in act_types:
                native = self.act_native_params
                return {export_key: native[native_key] for export_key, native_key in key_map.items()}

        if act in [ActivationType.EXP_DECOMPOSE, ActivationType.SHIFT]:
            # The mask is optional: export it only when one was imported.
            mask = self.act_native_params.get("mask", None)
            return {} if mask is None else {"mask": mask}

        return {}

    @property
    def apply_fitting(self):
        """Result of ``_get_activation_fit_by_policy`` for the current activation and fit policy."""
        activation, policy = self.act_name, self._fit_policy
        return self._get_activation_fit_by_policy(activation, policy)

    def pl_approximate(self, accumulator_scale_candidate, optimization_target, utilize_wraparound=False):
        """
        Creating the generic piecewise-linear approximation of the activation function
        """
        apply_fitting = self.apply_fitting
        if self.act_name not in QUANTIZATION_SUPPORTED_ACTIVATION_TYPES:
            self._logger.debug(f"Don't support PL-approximation for {self.act_name.value} yet, forcing fully native")
            self.thresholds = self.offsets = self.slopes = tuple()
            self.thresholds = np.array(self.thresholds, dtype=self.FLOAT_TYPE_NP)
        elif apply_fitting:
            self._apply_activation_fitting(optimization_target)
        elif self.act_name == ActivationType.RELU:
            self.thresholds = (0.0,)
            self.offsets = (0.0, 0.0)
            self.slopes = (0.0, 1.0)
        elif self.act_name == ActivationType.LINEAR:
            self.thresholds = tuple()
            self.offsets = (0.0,)
            self.slopes = (1.0,)
        elif self.act_name == ActivationType.EXP:
            self.thresholds = (
                -4.082122906418818,
                -2.9239880997250993,
                -2.1378183236757744,
                -1.5450620171809504,
                -1.0622468763445756,
                -0.6515717357389873,
                -0.2966252623180404,
            )
            self.offsets = (
                0.03994322110379119,
                0.14077170431233604,
                0.2865104384326554,
                0.45555710809747246,
                0.6297955438794169,
                0.7923317480828821,
                0.9219184637672171,
                0.9943104796460943,
            )
            self.slopes = (
                0.006399217042686561,
                0.031099228659191284,
                0.08094167984222486,
                0.16001606506601576,
                0.27278722495761154,
                0.42579892856721674,
                0.6246821958675745,
                0.8687342881039356,
            )
        elif self.act_name == ActivationType.RELU6:
            self.thresholds = (0.0, 6.0)
            self.offsets = (0.0, 0.0, 6.0)
            self.slopes = (0.0, 1.0, 0.0)
        elif self.act_name == ActivationType.RELU1:
            self.thresholds = (0.0, 1.0)
            self.offsets = (0.0, 0.0, 1.0)
            self.slopes = (0.0, 1.0, 0.0)
        elif self.act_name == ActivationType.SIGMOID:
            self.thresholds = (
                -5.53733427,
                -3.09927295,
                -1.90616982,
                -0.99164017,
                0.99164017,
                1.90616982,
                3.09927295,
                5.53733427,
            )
            self.offsets = (
                0.0,
                0.08810424712922735,
                0.2644558437736866,
                0.4227523210348685,
                0.5,
                0.5772476789651317,
                0.7355441562263132,
                0.911895752870773,
                1.0,
            )
            self.slopes = (
                0.0,
                0.01627779998158698,
                0.07317875691229528,
                0.15622302484702835,
                0.23412192537778395,
                0.15622302484702819,
                0.07317875691229535,
                0.01627779998158692,
                0.0,
            )
        elif self.act_name == ActivationType.LEAKY:
            if len(self.act_numeric_params.keys()) == 0:
                self.act_native_params["alpha"] = 0.3
                self.act_numeric_params["leaky_alpha"] = 0.3
            self.thresholds = (0.0,)
            self.offsets = (0.0, 0.0)
            self.slopes = (self.act_numeric_params["leaky_alpha"], 1.0)
        elif self.act_name == ActivationType.ELU:
            self.thresholds = (-2.30258509, -1.2039728, -0.51082562, 0.0)
            self.offsets = (-0.8200000003908651, -0.4908193461953715, -0.18891032320148804, 0.0, 0.0)
            self.slopes = (0.039086503252366206, 0.18204784510466387, 0.4328085126163248, 0.8026222333954197, 1.0)
        elif self.act_name == ActivationType.TANH:
            self.thresholds = (-2.6210944, -1.61987821, -1.02180768, -0.53728883, 0.60000386, 1.25945738, 2.28178297)
            self.offsets = (
                -0.9965326429978993,
                -0.8375309346632831,
                -0.5254454857271249,
                -0.19696030088170047,
                -0.001078939203302709,
                0.27988605444820147,
                0.7239594024348832,
                0.9937275409836461,
            )
            self.slopes = (
                0.00034673570021006206,
                0.06100907137748144,
                0.25366889420216054,
                0.5751434644905972,
                0.9397171364919261,
                0.4714451596224705,
                0.11885415095595916,
                0.0006272459016353909,
            )
        elif self.act_name == ActivationType.SOFTPLUS:
            self.thresholds = (-4.0, -2.33749224, -1.26575537, -0.40175493, 0.40645237, 1.27026613, 2.34072059, 4.0)
            self.offsets = (
                0.0169623,
                0.1814718105294551,
                0.4182164501261004,
                0.6210950945764345,
                0.6999223275213164,
                0.6202472654460531,
                0.41729326541731315,
                0.1812120233895751,
                0.010219049999999896,
            )
            self.slopes = (
                0.00169623,
                0.04282360763236377,
                0.14410507310437123,
                0.30438773850624434,
                0.5005949958630662,
                0.6966205771021754,
                0.856393395754547,
                0.9572517566526062,
                1.0,
            )
        elif self.act_name == ActivationType.SILU or self.act_name == ActivationType.SWISH:
            if optimization_target != OptimizationTarget.SAGE:
                # allows negative slopes
                self.thresholds = [
                    -6.0,
                    -3.89303964,
                    -1.3473389,
                    -0.63382168,
                    -0.02343757,
                    0.59187625,
                    1.32987503,
                    3.89463346,
                ]
                self.slopes = [
                    0.0,
                    -0.02893179,
                    -0.08676027,
                    0.07889137,
                    0.33924308,
                    0.63803194,
                    0.91142625,
                    1.08662166,
                    1.0,
                ]
                self.offsets = [
                    0.0,
                    -0.18177579,
                    -0.40690438,
                    -0.18371548,
                    -0.01869892,
                    -0.01169604,
                    -0.17351164,
                    -0.40649964,
                    0.0,
                ]
            else:
                self.thresholds = (-5.0, -3.388, -2.339, -1.0, -0.282, 0.384, 1.127, 5.0)
                self.offsets = (
                    0.0,
                    -0.066,
                    -0.156,
                    -0.263,
                    -0.09169080779944291,
                    0.00819819819819817,
                    -0.11301480484522211,
                    -0.37493183578621236,
                    0.0,
                )
                self.slopes = (
                    0.0,
                    0.0,
                    0.0,
                    0.0,
                    0.1713091922005571,
                    0.5255255255255256,
                    0.8411843876177658,
                    1.0735863671572425,
                    1.0,
                )
            if self.act_name == ActivationType.SWISH:
                if len(self.act_numeric_params.keys()) == 0:
                    self.act_numeric_params["swish_beta"] = 1
                    self.act_native_params["swish_beta"] = 1
                beta = self.act_numeric_params["swish_beta"]

                if beta == 0:
                    self.thresholds = tuple()
                    self.offsets = (0.0,)
                    self.slopes = (0.5,)
                else:
                    self.thresholds = tuple(x / beta for x in self.thresholds)
                    self.offsets = tuple(x / beta for x in self.offsets)
                    if beta < 0:
                        self.thresholds = tuple(reversed(self.thresholds))
                        self.offsets = tuple(reversed(self.offsets))
                        self.slopes = tuple(reversed(self.slopes))
        elif self.act_name == ActivationType.BIASED_DELTA:
            if len(self.act_numeric_params.keys()) == 0:
                self.act_native_params["activation_delta_bias"] = -1
                self.act_numeric_params["activation_delta_bias"] = -1
            self.thresholds = (-0.000001, 0.000001)
            self.offsets = (
                self.act_numeric_params["activation_delta_bias"],
                0.0,
                self.act_numeric_params["activation_delta_bias"],
            )
            self.slopes = (0.0, 0.0, 0.0)
        elif self.act_name == ActivationType.DELTA:
            self.thresholds = (-0.000001, 0.000001)
            self.offsets = (0, 1, 0)
            self.slopes = (0.0, 0.0, 0.0)
        elif self.act_name == ActivationType.LESS:
            if len(self.act_native_params.keys()) == 0:
                self.act_native_params["y"] = 0
                self.act_numeric_params["activation_less_values"] = 0
            # we support only scalar th in the HW, if the layer is native
            # only, thresholds will be ignored
            self.thresholds = [np.mean(self.act_native_params["y"])]
            self.offsets = (1.0, 0.0)
            self.slopes = (0.0, 0.0)
        elif self.act_name == ActivationType.GREATER:
            if len(self.act_native_params.keys()) == 0:
                self.act_native_params["y"] = 0
                self.act_numeric_params["activation_greater_values"] = 0
            # we support only scalar th in the HW, if the layer is native
            # only, thresholds will be ignored
            self.thresholds = [np.mean(self.act_native_params["y"]) + 0.000001]
            self.offsets = (0.0, 1.0)
            self.slopes = (0.0, 0.0)
        elif self.act_name == ActivationType.THRESHOLD:
            if "theta" not in self.act_native_params:
                self.act_native_params["theta"] = 1
            self.act_numeric_params["activation_threshold"] = 1
            self.thresholds = [self.act_native_params["theta"]]
            self.offsets = (0.0, 0.0)
            self.slopes = (0.0, 1.0)
        elif self.act_name == ActivationType.HARDSIGMOID:
            if len(self.act_numeric_params.keys()) == 0:
                self.act_numeric_params["hardsigmoid_alpha"] = 1
                self.act_numeric_params["hardsigmoid_beta"] = 0.5

            beta = self.act_numeric_params["hardsigmoid_beta"]
            alpha = self.act_numeric_params["hardsigmoid_alpha"]

            self.thresholds = (-beta / alpha, (1.0 - beta) / alpha)
            self.offsets = (0.0, beta, 1.0)
            self.slopes = (0.0, alpha, 0.0)
        elif self.act_name == ActivationType.GELU:
            if optimization_target != OptimizationTarget.SAGE:
                # allows negative slopes
                self.thresholds = (-3.62189, -2.2877, -0.80275, -0.37682, -0.0012, 0.37014, 0.79108, 2.36739)
                self.offsets = (0, -0.04368, -0.26457, -0.11051, -0.00921, -0.00886, -0.10773, -0.26145, -0.02197)
                self.slopes = (0, -0.01206, -0.10862, 0.0833, 0.35214, 0.64485, 0.91198, 1.10629, 1.00514)
            else:
                self.thresholds = (
                    -5.0,
                    -2.07481226,
                    -1.39832816,
                    -0.5,
                    0.06593052252088966,
                    0.6542365682250897,
                    2.4044649501376,
                )
                self.offsets = (
                    0.0,
                    -0.005706,
                    -0.0731837,
                    -0.15300025,
                    -0.007825997314101968,
                    -0.03677881345848943,
                    -0.25332852321526356,
                    -0.020574868020072138,
                )
                self.slopes = (
                    0,
                    0,
                    0,
                    0,
                    0.33146264305828643,
                    0.7706039548113364,
                    1.1015999893868147,
                    1.0047993871647396,
                )
        elif self.act_name == ActivationType.HARDSWISH:
            if optimization_target != OptimizationTarget.SAGE:
                # allows negative slopes
                self.thresholds = (-3.06885, -2.01602, -1.00132, 0.00917, 1.0155, 2.02181, 3.06805)
                self.offsets = (0.0, -1.04722, -0.36481, -0.02669, -0.02977, -0.37014, -1.04935, 0.0)
                self.slopes = (0.0, -0.34124, -0.00275, 0.33493, 0.67091, 1.00608, 1.34202, 1.0)
            else:
                self.thresholds = (
                    -5.0,
                    -2.79068793,
                    -2.29763961,
                    -1.0,
                    0.34967460211345386,
                    1.7031505404857883,
                    3.1313349911703496,
                )
                self.offsets = (
                    0.0,
                    -0.00475815,
                    -0.19005807,
                    -0.34790991,
                    0.007882841466477386,
                    -0.15102804428818295,
                    -0.9200629961043796,
                    0,
                )
                self.slopes = (
                    0,
                    0,
                    0,
                    0,
                    0.39177922706030094,
                    0.8430048488085773,
                    1.2938928740154652,
                    1.0000000247949148,
                )
        elif (self.act_name == ActivationType.MINUS_INV_POS) or (
            self.act_name == ActivationType.INV_POS and self.act_native_params["inverse_act_factor"] < 0
        ):
            self.thresholds = (
                -30.898276640325783,
                -17.160624926061416,
                -9.972874297447543,
                -6.0966997881327165,
                -3.897010109467108,
                -2.5583847372716857,
                -1.7281572472978795,
            )
            self.offsets = (
                0.04625268484641968,
                0.08641456372450479,
                0.15212564687007252,
                0.25602278923614236,
                0.4091856499535541,
                0.6308744001342965,
                0.9499948456465039,
                1.3772920466362835,
            )
            self.slopes = (
                0.0005158008339733299,
                0.0018156105076840398,
                0.00564478825783543,
                0.016062761980081313,
                0.04118501927337049,
                0.09807190022915455,
                0.22280702738306496,
                0.47006299998503454,
            )
        elif self.act_name == ActivationType.INV_SQRT:
            # This calculation is based on the assumption that the optimal solution for the pwla is continuous.
            # Under this assumption, our thresholds are a geometric sequence with ratio r = (min/max)**(1/num_of_slope).
            number_of_slopes = 8 if utilize_wraparound else 9
            preact_stats = self.get_input_stats(0)
            max_value = np.max(preact_stats.max)
            min_value = np.min(preact_stats.min)
            # We clip the ratio to insures that last slope is not too steep so we could avoid fatal negative exponent.
            # The max_ratio was approximated for 9 and 8 slopes such that the resulted threshold would be similar to
            # ones that were calculated by activation_fitting.
            max_ratio_9_slopes = 3.175
            max_ratio_8_slopes = 3.744
            max_ratio = max_ratio_9_slopes if number_of_slopes == 9 else max_ratio_8_slopes
            ratio = np.minimum(max_ratio, (min_value / max_value) ** (1 / number_of_slopes))
            factor = self.act_native_params["inverse_act_factor"]

            # This solution minimize the l2 loss of the function at the range [break_points[i] * ratio, break_points[i]]
            break_points = min_value / (ratio ** (np.arange(number_of_slopes) + 1))
            self.thresholds = tuple(break_points[:-1])
            self.slopes = tuple(-(4 / break_points) / (np.sqrt(factor * break_points) * (1 + np.sqrt(ratio)) ** 3))
            self.offsets = tuple(
                4 * (1 + np.sqrt(ratio) + ratio) / (np.sqrt(factor * break_points) * (1 + np.sqrt(ratio)) ** 3)
            )
        elif self.act_name == ActivationType.CLIP:
            if len(self.act_numeric_params.keys()) == 0:
                self.act_numeric_params["clip_min"] = -1
                self.act_numeric_params["clip_max"] = 1
            self.thresholds = (self.act_numeric_params["clip_min"], self.act_numeric_params["clip_max"])
            self.offsets = (self.act_numeric_params["clip_min"], 0, self.act_numeric_params["clip_max"])
            self.slopes = (0, 1, 0)
            if self.act_numeric_params["clip_min"] == -np.inf:
                self.thresholds = self.thresholds[1:]
                self.offsets = self.offsets[1:]
                self.slopes = self.slopes[1:]
            if self.act_numeric_params["clip_max"] == np.inf:
                self.thresholds = self.thresholds[:-1]
                self.offsets = self.offsets[:-1]
                self.slopes = self.slopes[:-1]
        elif self.act_name == ActivationType.MISH:
            if optimization_target != OptimizationTarget.SAGE:
                # allows negative slopes
                self.thresholds = (-7.88799, -5.17976, -3.57967, -1.23905, -0.54349, 0.0408, 0.65957, 2.6477)
                self.offsets = (0, -0.06831, -0.24449, -0.44588, -0.19312, -0.01091, -0.02584, -0.19933, -0.03687)
                self.slopes = (0, -0.00866, -0.04267, -0.09893, 0.10507, 0.44033, 0.80639, 1.06942, 1.00806)
            else:
                self.thresholds = (-5.0, -3.314, -2.248, -1.0, -0.21, 0.541, 2.72, 5.0)
                self.offsets = (0.0, -0.068, -0.169, -0.281, -0.10378481, 0.00608389, -0.19125562, -0.0342807, 0.0)
                self.slopes = (0.0, 0.0, 0.0, 0.0, 0.17721519, 0.70039947, 1.06516751, 1.00745614, 1.0)
        elif self.act_name == ActivationType.LOG:
            self.thresholds = (
                0.07794553037629727,
                0.1938189162312083,
                0.41324522734184127,
                0.791286344484762,
                1.4012354720097504,
                2.3179230522819476,
                3.5528835260901928,
            )
            self.offsets = (
                -4.088927261456292,
                -3.0627204956169862,
                -2.241621583207326,
                -1.5438001509456627,
                -0.9366866038742211,
                -0.3994256666323819,
                0.06848375523178496,
                0.4494249564496442,
            )
            self.slopes = (
                20.784167315172787,
                7.6184763398841255,
                3.382053350963042,
                1.693415742852405,
                0.9261668811691447,
                0.5427474288544943,
                0.3408816156630169,
                0.23366132588417635,
            )
        elif self.act_name == ActivationType.SOFTSIGN:
            self.slopes = (
                0.035778508856313715,
                0.11191418754047931,
                0.31168408582069423,
                0.7858535893689282,
                0.37962274332782286,
                0.17737589789972943,
                0.0781091982405676,
                0.03185518120021198,
            )
            self.offsets = (
                -0.6518877135661099,
                -0.43490475366430625,
                -0.1842943702445171,
                0.0012988933432609984,
                0.14047550124623503,
                0.32967226844486086,
                0.5149580096686003,
                0.6716546287693659,
            )
            self.thresholds = (
                -2.8499510827494707,
                -1.2544952246422079,
                -0.3914070015025733,
                0.34260472649802437,
                0.9354745029429521,
                1.8665447915557698,
                3.3877407656085583,
            )
        elif self.act_name == ActivationType.HDR_COMPRESSION:
            self.slopes = (0.5, 0.25, 0.125, 0.03125)
            self.offsets = (0, 4096, 8192, 32768)
            self.thresholds = (2**14, 2**15, 2**18)
        elif self.act_name == ActivationType.PWL:
            self.slopes = tuple(self.act_native_params["slopes"])
            self.offsets = tuple(self.act_native_params["offsets"])
            self.thresholds = tuple(self.act_native_params["thresholds"])
        elif self.act_name == ActivationType.EXP_DECOMPOSE:
            mask = self.act_native_params["mask"]
            larger_mask = np.maximum(mask, -mask[::-1])[len(mask) // 2 :]
            self.thresholds = tuple(mask[1:-1])
            self.offsets = tuple(3 * np.concatenate([larger_mask[::-1], larger_mask[1:]]))
            self.slopes = tuple(0.0 for _ in range(len(self.thresholds) + 1))
        elif self.act_name == ActivationType.SHIFT:
            mask = self.act_native_params["mask"]
            larger_mask = np.maximum(mask, -mask[::-1])[len(mask) // 2 :]
            buffer = larger_mask[-1] / 2**15  # small buffer that separate between thresholds.
            self.thresholds = tuple(4 * larger_mask[:-1] - buffer)
            self.offsets = tuple(-3.0 for _ in range(len(self.thresholds) + 1))
            self.slopes = tuple(1.0 / larger_mask)
        else:
            raise ValueError(f"Failed PL-approximation, unexpected activation {self.act_name} in op {self.full_name}")

        if self.remove_offsets:
            self._crop_offsets()

        if utilize_wraparound:
            if len(self.thresholds) == 8:
                raise AccelerasInitializationError(
                    f"wraparound utilization is needed but number of segment exceed maximum capacity at layer {self.full_name}."
                )
            limval = np.max(accumulator_scale_candidate) * self.input_lossy_elements[0].max_value
            self.thresholds = (
                tuple(threshold for threshold in self.thresholds)
                + (0.0,)
                + tuple(threshold + 2 * limval for threshold in self.thresholds)
            )
            self.offsets = tuple(offset for offset in self.offsets) + tuple(
                offset - 2 * limval * slope for offset, slope in zip(self.offsets, self.slopes)
            )
            self.slopes = tuple(slope for slope in self.slopes) + tuple(slope for slope in self.slopes)
            indices = [index for index, threshold in enumerate(self.thresholds) if -limval < threshold < limval]
            start_index, end_index = indices[0], indices[-1] + 1
            self.thresholds = self.thresholds[start_index:end_index]
            self.offsets = self.offsets[start_index : end_index + 1]
            self.slopes = self.slopes[start_index : end_index + 1]

        self.thresholds = np.array(self.thresholds, dtype=self.FLOAT_TYPE_NP)

    def _get_activation_fit_by_policy(self, activation_name, activation_fit_policy):
        """
        Decide whether the activation-fitting optimization should be applied.

        Args:
            activation_name: the activation type to decide for.
            activation_fit_policy: the requested ActivationFitPolicy value.

        Returns:
            bool: True when fitting should be applied. The `allowed` policy resolves to the
            membership of `activation_name` in ACTIVATIONS_TO_FIT; any other policy applies
            fitting only when it equals `enabled`.

        """
        if activation_fit_policy == ActivationFitPolicy.allowed:
            # "allowed" delegates the decision to the ACTIVATIONS_TO_FIT collection.
            return activation_name in ACTIVATIONS_TO_FIT
        return ActivationFitPolicy.enabled == activation_fit_policy

    def set_fit_policy(self, fit_policy: ActivationFitPolicy):
        """
        Validate the requested fit policy against this op's activation and store it.

        Args:
            fit_policy: the ActivationFitPolicy to store in `self._fit_policy`.

        Raises:
            AccelerasImplementationError: when fitting is requested for an activation that is not in
                ACTIVATIONS_FITTING_SUPPORTED, or when fitting is not requested for an activation
                listed in MUST_FIT_ACTIVATIONS.

        """
        should_fit = self._get_activation_fit_by_policy(self.act_name, fit_policy)
        if should_fit:
            # Fitting was requested: the activation has to be one we know how to fit.
            if self.act_name not in ACTIVATIONS_FITTING_SUPPORTED:
                raise AccelerasImplementationError(
                    f"Acceleras doesn't support activation fit on {self.act_name} in layer {self.full_name}",
                )
        elif self.act_name in MUST_FIT_ACTIVATIONS:
            # Fitting was not requested, but this activation cannot go without it.
            raise AccelerasImplementationError(
                f"Acceleras requires activation fit on {self.act_name} in layer {self.full_name}",
            )
        self._fit_policy = fit_policy

    def create_output_encoding_candidates(
        self,
        output_index,
        forced_range=None,
        output_lossy_external=None,
        translation_config=None,
        split_precision_zp=None,
    ):
        """
        Create the output scale / zero-point candidates for the given output.

        The DELTA and BIASED_DELTA activations get a fixed encoding derived from the bit width of
        the output lossy element; every other activation is delegated to the parent implementation.

        Args:
            output_index: index of the output whose encoding is created.
            forced_range: forwarded to the parent implementation.
            output_lossy_external: forwarded to the parent implementation.
            translation_config: forwarded to the parent implementation.
            split_precision_zp: forwarded to the parent implementation.

        """
        if self.act_name not in (ActivationType.BIASED_DELTA, ActivationType.DELTA):
            super().create_output_encoding_candidates(
                output_index,
                forced_range,
                output_lossy_external,
                translation_config=translation_config,
                split_precision_zp=split_precision_zp,
            )
            return

        # Largest value representable with the output bit width.
        full_scale = 2 ** self.output_lossy_elements[output_index].bits - 1
        if self.act_name == ActivationType.BIASED_DELTA:
            zero_point = full_scale
            scalar_scale = -self.act_numeric_params["activation_delta_bias"] / full_scale
        else:
            zero_point = 0
            scalar_scale = 1 / full_scale

        # The same scalar scale is replicated for every output channel.
        per_channel_scale = np.repeat(np.array(scalar_scale, self.FLOAT_TYPE_NP), self.output_shape[-1])
        self.output_scales[output_index] = per_channel_scale
        self.output_zero_points[output_index] = zero_point

    def create_hw_params(
        self, accumulator_scale_candidate, optimization_target, nudging=True, utilize_wraparound=False
    ):
        """
        Prepare the op for encoding creation from an accumulator scale candidate.

        After validating that the activation can be quantized, the piecewise-linear approximation is
        built, the output factor candidates (accumulator scale over output scale) are reduced per
        quantization group, the mantissa/exponent decomposition is refreshed and the encoding is
        enforced.

        NOTE: the name may suggest "full quant" semantics, which is misleading; TODO consider renaming.
        TODO: quantization groups support (created fully scalar as of now)

        Args:
            accumulator_scale_candidate: candidate scale of the activation input (accumulator).
            optimization_target: forwarded to `pl_approximate`.
            nudging: forwarded to `_get_output_factor_by_group`.
            utilize_wraparound: forwarded to `pl_approximate`; requires `is_negative_input()` to hold.

        Raises:
            AccelerasImplementationError: the activation is native-only, not supported for
                quantization, or is LESS/GREATER with non-uniform vector thresholds.
            AccelerasValueError: wraparound was requested but `is_negative_input()` is False.

        """
        if self.act_name in NATIVE_ONLY_ACTIVATION_TYPES:
            raise AccelerasImplementationError(
                f"Activation {self.act_name} has no quantized implementation in acceleras. Only native computation is supported.",
            )
        if self.act_name not in QUANTIZATION_SUPPORTED_ACTIVATION_TYPES:
            raise AccelerasImplementationError(
                f"acceleras doesn't support quantization of {self.act_name} activation yet",
            )
        if self.act_name in [ActivationType.LESS, ActivationType.GREATER]:
            compared_values = self.act_native_params["y"]
            is_vector = np.array(compared_values).ndim > 0
            # A vector of thresholds is acceptable only when all of its entries are the same value.
            if is_vector and len(np.unique(compared_values)) != 1:
                raise AccelerasImplementationError(
                    f"acceleras doesn't support quantization of {self.act_name} with vector thresholds",
                )
        if utilize_wraparound and not self.is_negative_input():
            raise AccelerasValueError(
                f"activation {self.act_name} cannot utilize accumulator wraparound as the expected input might be positive",
            )

        self.pl_approximate(accumulator_scale_candidate, optimization_target, utilize_wraparound=utilize_wraparound)
        factor_candidates = accumulator_scale_candidate / self.output_scale
        self.output_factor_by_group = self._get_output_factor_by_group(factor_candidates, nudging=nudging)
        # TODO: Currently we can't train output_factors, therefore the slopes can be decomposed at creation time.
        self.update_mantissa_exponent_decomposition()
        self.enforce_encoding()

    def _get_output_factor_by_group(self, output_factors_val, eps=1e-3, nudging=True):
        """
        Reduce the output factors to one value per quantization group.

        The factors are split along the last axis into `self.quantization_groups_num` groups and
        averaged within each group. When `nudging` is set and exactly one slope is non-zero, the
        averages are replaced by the factor reconstructed from the rounded-up mantissa and the exponent
        factor of that slope (the "nudging" that makes single-slope activations precise in mantissa).

        Args:
            output_factors_val: the output factor candidates; its last axis is split into groups.
            eps (float, optional): relative deviation from the group mean above which a warning is
                logged. Defaults to 1e-3.
            nudging (bool, optional): enable nudging of output_factors when exist a single slope != 0.
                Defaults to True.

        Returns:
            one output factor per group: a numpy array of `self.FLOAT_TYPE_NP`, or the result of
            `tf.squeeze` when the nudging path is taken.

        """
        # The spread inside a group is only checked for non-scalar numpy inputs.
        to_check = isinstance(output_factors_val, np.ndarray) and output_factors_val.shape != ()

        split_dim = -1
        # Only the second returned element is used; it is indexed below as consecutive group boundaries.
        _, split_points, _ = get_quantization_groups_info(
            self.quantization_groups_num,
            output_factors_val.shape[split_dim],
            self.validate_shapes,
        )

        output_factors_by_group = []
        for i in range(self.quantization_groups_num):
            # Take the factors that fall in group i along the split axis.
            output_factors_part = tf.gather(
                output_factors_val,
                indices=np.arange(split_points[i], split_points[i + 1]).astype("int"),
                axis=split_dim,
            )
            mean_output_factor = np.mean(output_factors_part)
            if to_check:
                # Largest deviation from the group mean, relative to the mean itself.
                diff = np.max(np.abs(output_factors_part - mean_output_factor) / mean_output_factor)
                if eps < diff:
                    self._logger.warning(f"the output_factor should be a scalar {diff}")
            output_factors_by_group.append(mean_output_factor)
        output_factors_by_group = np.array(output_factors_by_group, self.FLOAT_TYPE_NP)
        # Perform the HW-compatible decomposition for single slope != 0.
        # This is under the assumption that enable_lossy was called.
        if nudging and np.count_nonzero(self.slopes) == 1:
            # The single non-zero slope of the piecewise-linear function.
            nonzero = self.slopes[np.flatnonzero(self.slopes)[0]]
            exponent_factors, mantissas_candidate, exponents = self._get_mantissa_exponent_decomposition(
                np.array([nonzero], dtype=self.FLOAT_TYPE_NP),
                output_factors_by_group,
            )
            mantissas = self.weight_lossy_elements.mantissa_round_up(mantissas_candidate)
            # Compute the actual HW overall factor assuming identity-activation (slope=1)
            # The decomposition has a leading axis of size 1 (single slope), which is squeezed away.
            output_factors_by_group = tf.squeeze(
                exponent_factors * mantissas * self.final_shift_factor / nonzero,
                axis=0,
            )
        return output_factors_by_group

    def force_input_range(self, force_limvals, index):
        """
        Overwrite the min/max statistics of an input with a forced range.

        Args:
            force_limvals: indexable pair; element 0 is the forced minimum, element 1 the forced maximum.
            index: index of the input whose statistics are updated.

        """
        stats = self.get_input_stats(index)
        # Both replacement arrays are built (with the shape of the existing stats) before any update.
        forced_values = (
            (TypeStats.MIN, np.ones_like(stats.min) * force_limvals[0]),
            (TypeStats.MAX, np.ones_like(stats.max) * force_limvals[1]),
        )
        for stat_type, forced in forced_values:
            update_stats(stats, forced, stat_type, clear_cannot_update=True)

    def calc_min_input_scale(self):
        """
        Return the min input scale (aka accumulator scale) for every channel.

        The scale is the largest absolute value of the input min/max statistics divided by the
        maximal value of the input lossy element. Entries whose statistics are exactly zero are
        replaced by a small epsilon first, so the resulting scale is never zero.

        Returns:
            the scale, with the shape of the input statistics. Scalar (0-d) statistics are
            supported as well as per-channel arrays.

        """
        default_max = 1e-5  # Epsilon Value that enables slopes.
        input_stats = self.get_input_stats(0)
        abs_max_per_channel = np.asarray(np.maximum(np.abs(input_stats.max), np.abs(input_stats.min)))
        # np.where is used instead of in-place masked assignment because the latter fails on
        # scalar statistics; the epsilon is cast to the stats dtype to keep the result dtype unchanged.
        abs_max_per_channel = np.where(
            abs_max_per_channel == 0.0,
            np.asarray(default_max, dtype=abs_max_per_channel.dtype),
            abs_max_per_channel,
        )
        max_quant_value = self.input_lossy_elements[0].max_value
        return abs_max_per_channel / max_quant_value

    def get_assigned_exponent(self, exponent=None):
        """
        Return the exponent relative to the APU exponent bias.

        Args:
            exponent: the exponent to convert; when None, the negated `self.exponents` is used.

        Returns:
            `exponent - self.apu_exp_bias`.

        """
        base_exponent = -self.exponents if exponent is None else exponent
        return base_exponent - self.apu_exp_bias

    def is_valid_exponent(self):
        """
        Check that every assigned exponent lies in the inclusive range [0, 2**APU_EXP_BITS - 1].

        Returns:
            the result of `np.all` over the element-wise range check.

        """
        assigned = self.get_assigned_exponent()
        upper_bound = 2**APU_EXP_BITS - 1
        within_range = (assigned >= 0) & (assigned <= upper_bound)
        return np.all(within_range)

    def assertion_negative_slope(self):
        """
        Tell whether the HW-limitation assertion is active.

        Returns:
            bool: False when the ignore-HW-limitation-assertion policy is `enabled`, True otherwise.

        """
        return self._ignore_hw_limitation_assertion != IgnoreHwLimitationAssertionPolicy.enabled

    def check_exp_range(self):
        """
        Verify that the assigned exponents are within the valid range.

        Nothing is checked when the ignore-HW-limitation-assertion policy is `enabled`.

        Raises:
            AccelerasNegativeSlopesError: when `is_valid_exponent()` is False; it is given the
                exponents and the valid range, both shifted by `self.apu_exp_bias`.

        """
        ignore_assertion = self._ignore_hw_limitation_assertion == IgnoreHwLimitationAssertionPolicy.enabled
        if ignore_assertion or self.is_valid_exponent():
            return

        bias = self.apu_exp_bias
        # Report both the exponents and the allowed range with the bias added back.
        offending_exponents = self.get_assigned_exponent() + bias
        allowed_range = np.array([0, 2**APU_EXP_BITS - 1]) + bias
        raise AccelerasNegativeSlopesError(offending_exponents.numpy(), allowed_range)

    def get_accumulator_scale(self):
        """
        Compute and store the real scale of the activation input (aka accumulator).

        The scale is the output scale multiplied by the output factors; it is written to
        `self.input_scales[0]` and then returned from there.
        """
        # TODO when fully_native/pl_available skip hw-compatibility
        accumulator_scale = self.output_scale * self.output_factors
        self.input_scales[0] = accumulator_scale
        return self.input_scales[0]

    def set_quantization_groups(self, quantization_groups):
        """
        Reset `output_factor_by_group` to ones, one entry per quantization group.

        Args:
            quantization_groups: number of groups; None is treated as a single group.

        """
        group_count = 1 if quantization_groups is None else quantization_groups
        self.output_factor_by_group = self.FLOAT_TYPE_NP([1] * group_count)

    def update_mantissa_exponent_decomposition(self):
        """
        Recompute the per-group exponent factors, mantissa candidates and exponents.

        Should be called after `self.slopes` or `self.output_factor_by_group` has changed.
        """
        slopes_array = np.array(self.slopes, dtype=self.FLOAT_TYPE_NP)
        decomposition = self._get_mantissa_exponent_decomposition(slopes_array, self.output_factor_by_group)
        (
            self.exponent_factors_by_group,
            self.mantissas_candidate_by_group,
            self.exponents_by_group,
        ) = decomposition

    def _get_mantissa_exponent_decomposition(self, slopes, output_factors):
        """
        Decompose every (slope, output factor) product into a mantissa candidate and an exponent.

        The factor to represent is `slope * output_factor / self.final_shift_factor`, computed for
        every piece and every quantization group.

        Args:
            slopes: an array of slopes, one per piece
            output_factors: the output factors, one per quantization group
        Returns: exponent_factors - 2**(exponents + self.shift_data)
                 mantissas_candidate - the mantissa we use to represent slopes (with no rounding)
                 exponents - the exponents we use to represent the slopes (after subtracting self.shift_data)
                 all returned values are in the shape of (number_of_pieces, number_of_quantization_groups)
        """
        # support channel-wise by using 2-D matrices (PL-piece-ind x channel-ind) along the lines of:
        slopes = tf.expand_dims(slopes, 1)  # making it a column
        # making it a row - this is a
        rescale_factor = tf.expand_dims(output_factors, 0)
        # matrix of shape (num_pieces, num_quantization_groups)
        factor_candidates = tf.matmul(slopes, rescale_factor) / self.final_shift_factor

        # Handle zero slope:
        # log2(0) is undefined, so a dummy factor replaces zero-slope entries for the exponent
        # computation only; the mantissa candidate below still uses the real (zero) factor.
        dummy_factor_candidates = 2 ** (-self.apu_exp_bias - 1 + APU_MANTISSA_BITS)
        non_zero_factor_candidates = tf.where(slopes != 0, factor_candidates, dummy_factor_candidates)

        # Smallest exponent for which |factor| * 2**-exponent does not exceed 2**APU_MANTISSA_BITS.
        exponents = tf.math.ceil(
            tf.experimental.numpy.log2(tf.math.abs(non_zero_factor_candidates)) - APU_MANTISSA_BITS
        )

        # TODO fix edge case where mantissa is 2**mantissa_bits. - check edge cases (what if mantissa rounds to 1024? there are more?)
        exponent_factors_inv = tf.math.pow(2.0, -1.0 * exponents)
        mantissas_candidate = factor_candidates * exponent_factors_inv

        # if the input is 16 bit we will assume it is "32" (same as 32->16 where the input is shifted by 8bits)
        exponents -= self.shift_data

        # fix upper overflow in the exponent
        # Exponents below -max_exponent are clamped to it, and the mantissa is scaled down by the
        # same power of two so that mantissa * 2**exponent keeps its value.
        max_exponent = self.apu_exp_bias + (2**APU_EXP_BITS - 1)
        mantissas_candidate = mantissas_candidate * 2 ** tf.math.minimum(exponents + max_exponent, 0)
        exponents = tf.math.maximum(exponents, -max_exponent)

        # if the mantissa is 1024 div by 2
        # Candidates at or above 2**APU_MANTISSA_BITS - 0.5 are halved and their exponent is incremented.
        mantissa_overflow = mantissas_candidate >= 2**APU_MANTISSA_BITS - 0.5
        exponents = tf.where(mantissa_overflow, exponents + 1, exponents)
        mantissas_candidate = tf.where(mantissa_overflow, mantissas_candidate / 2, mantissas_candidate)
        # The exponent factor re-adds shift_data, which was subtracted from `exponents` above.
        exponent_factors = tf.math.pow(2.0, exponents + self.shift_data)

        return exponent_factors, mantissas_candidate, exponents

    def _group_to_vector(self, element_by_group):
        """
        Expand per-group values along the last axis into per-channel values.

        Every group entry is repeated `self.base_group_size` times, cast to `self.FLOAT_TYPE_TF`,
        and the first `self.num_of_channels` entries of the last axis are returned.
        """
        repeated = tf.repeat(element_by_group, self.base_group_size, axis=-1)
        per_channel = tf.cast(repeated, self.FLOAT_TYPE_TF)
        channel_indices = np.arange(self.num_of_channels).astype("int")
        return tf.gather(per_channel, indices=channel_indices, axis=-1)

    @property
    def output_factors(self):
        """Per-channel expansion of `output_factor_by_group`."""
        by_group = self.output_factor_by_group
        return self._group_to_vector(by_group)

    @property
    def exponent_factors(self):
        """Per-channel expansion of `exponent_factors_by_group`."""
        by_group = self.exponent_factors_by_group
        return self._group_to_vector(by_group)

    @property
    def mantissas_candidate(self):
        """Per-channel expansion of `mantissas_candidate_by_group`."""
        by_group = self.mantissas_candidate_by_group
        return self._group_to_vector(by_group)

    @property
    def exponents(self):
        """Per-channel expansion of `exponents_by_group`."""
        by_group = self.exponents_by_group
        return self._group_to_vector(by_group)

    def enforce_encoding(self, training=False, zp_factor=None):
        """
        Specialize the piecewise-linear pieces for the current I/O encodings.

        Combines the quantized mantissas with the exponent factors into the total slopes, and
        stores the offset candidates, the quantized offsets and the accumulator-domain thresholds
        produced by `_get_pieces_encoding`. `self.output_zero_point` is also converted to a tensor.

        Args:
            training: forwarded to the lossy elements.
            zp_factor: when given, the output zero point used for the pieces is floor-divided by it.

        """
        quantized_mantissas = self.weight_lossy_elements.mantissa(self.mantissas_candidate, training=training)
        total_slopes = quantized_mantissas * self.exponent_factors

        # The (possibly scaled) zero point is taken before the attribute is converted to a tensor.
        if zp_factor is None:
            effective_zero_point = self.output_zero_point
        else:
            effective_zero_point = self.output_zero_point // zp_factor
        self.output_zero_point = tf.convert_to_tensor(self.output_zero_point)

        pieces_encoding = self._get_pieces_encoding(
            total_slopes,
            self.input_scale,
            self.input_zero_point,
            self.output_scale,
            effective_zero_point,
            training=training,
        )
        self.offset_candidates, self.real_offsets, self.acc_thresholds = pieces_encoding

    def _get_pieces_encoding(
        self,
        total_slopes,
        input_scale,
        input_zero_point,
        output_scale,
        output_zero_point,
        training=False,
    ):
        """
        Translate the native thresholds and offsets of the pieces to the quantized domain.

        Note: for the BIASED_DELTA, DELTA and GREATER activations `self.thresholds` is overwritten
        here, based on half of the smallest input scale.

        Args:
            total_slopes: the slopes of the pieces as they will be applied (mantissa * exponent factor).
            input_scale: scale of the activation input.
            input_zero_point: zero point of the activation input.
            output_scale: scale of the activation output.
            output_zero_point: zero point of the activation output.
            training: forwarded to the offset and thresholds lossy elements.

        Returns:
            offset_candidates - the offsets before quantization
            real_offsets - the offsets after the offset lossy element
            acc_thresholds - the thresholds (padded with -inf and inf) in accumulator units, after
                             the thresholds lossy element
        """
        # in case of a biased delta activation, convert the pieces thresholds to 'half a bin'
        th = tf.reduce_min(input_scale) / 2
        if self.act_name in [ActivationType.BIASED_DELTA, ActivationType.DELTA]:
            self.thresholds = (-th, th)
        elif self.act_name == ActivationType.GREATER:
            # The threshold is placed half a bin above the mean of the native "y" values.
            self.thresholds = (np.mean(self.act_native_params["y"]) + th,)
        # add -inf and inf to the thresholds
        thresholds = tf.concat(([-np.inf], tf.cast(self.thresholds, self.FLOAT_TYPE_TF), [np.inf]), axis=0)

        # Thresholds and offsets become columns (one row per piece); scales and zero point become rows.
        thresholds = tf.cast(tf.expand_dims(thresholds, 1), self.FLOAT_TYPE_TF)
        offsets = tf.cast(tf.expand_dims(self.offsets, 1), self.FLOAT_TYPE_TF)
        output_scale_row = tf.cast(tf.expand_dims(output_scale, 0), self.FLOAT_TYPE_TF)
        input_scale_row = tf.cast(tf.expand_dims(input_scale, 0), self.FLOAT_TYPE_TF)
        input_zero_point_row = tf.cast(tf.expand_dims(input_zero_point, 0), self.FLOAT_TYPE_TF)

        # Native thresholds expressed in accumulator (quantized input) units.
        acc_thresholds = (thresholds / input_scale_row) + input_zero_point_row
        # Build the offset candidate (first at output scale) from its components:
        # (A) The "b" of y=ax+b in PL (B) Compensating the input zero point (C) Adding the output zero point
        offsets_at_outp_scale = offsets / output_scale_row
        offsets_at_outp_scale -= tf.cast(input_zero_point, self.FLOAT_TYPE_TF) * total_slopes * self.final_shift_factor
        offsets_at_outp_scale += tf.cast(output_zero_point, offsets_at_outp_scale.dtype)

        # Switch to the correct (intermediate) scale
        offset_candidates = offsets_at_outp_scale / self.final_shift_factor

        # (!) Offset quantization
        real_offsets = self.weight_lossy_elements.offset(offset_candidates, training=training)
        # (!) thresholds quantization
        acc_thresholds = self.weight_lossy_elements.thresholds(acc_thresholds, training=training)

        return offset_candidates, real_offsets, acc_thresholds

    def get_offset_needed_shift(self):
        """
        Return the power-of-two shift needed for the offset candidates to fit the offset bit width.

        The extreme offset candidates are compared against the positive and negative limits derived
        from the bit width of the offset lossy element; the result is `ceil(log2(worst ratio))`, or
        0 when the worst ratio is exactly zero.
        """
        offset_bits = self.weight_lossy_elements.offset.bits
        # NOTE(review): the limits are 2**(bits-1) and -2**(bits-1)-1, which are not the usual
        # two's-complement bounds - confirm this is intended.
        positive_limit = 2 ** (offset_bits - 1)
        negative_limit = -positive_limit - 1

        candidates = self.offset_candidates
        positive_ratio = np.max(candidates) / positive_limit
        negative_ratio = np.min(candidates) / negative_limit

        worst_ratio = np.max([positive_ratio, negative_ratio])
        if worst_ratio == 0:
            return 0
        return np.ceil(np.log2(worst_ratio))

    @property
    def bit_exact_supported(self) -> bool:
        """Whether bit exact emulation is supported by this layer (always True)."""
        return True

    @staticmethod
    def limit_bit_rounding(x, bits_to_round):
        """
        Round `x` after truncating it to `bits_to_round` fractional bits.

        The magnitude of `x` is first truncated (floored) to a multiple of 2**-bits_to_round, the
        sign is restored, and the result is rounded with `tf.math.round`. The rounding is therefore
        exact only up to `bits_to_round` bits of precision.
        """
        scale = 2**bits_to_round
        sign = tf.math.sign(x)
        truncated_magnitude = tf.math.floor(tf.math.abs(x) * scale)
        return tf.math.round(sign * truncated_magnitude / scale)

    def _compute_output_shape(self, input_shape):
        return input_shape

    def call_hw_sim(self, inputs, training=False, **kwargs):
        """
        Simulate the APU HW computation of the piecewise-linear approximation.

        For every element of the first input, the slope and offset of the piece whose
        [threshold_i, threshold_i+1) interval contains it are selected; the element is multiplied by
        the slope, passed through the clip1 lossy element, the offset is added, and the result is
        multiplied by `self.final_shift_factor`.

        Args:
            inputs: sequence whose first element is the accumulator tensor.
            training: forwarded to the lossy elements.

        """
        accumulator = inputs[0]
        pieces_count = len(self.slopes)
        quantized_mantissas = self.weight_lossy_elements.mantissa(self.mantissas_candidate, training=training)
        piece_slopes = quantized_mantissas * self.exponent_factors

        selected_slopes = tf.zeros_like(accumulator)
        selected_offsets = tf.zeros_like(accumulator)

        for piece in range(pieces_count):
            above_lower = tf.math.greater_equal(accumulator, self.acc_thresholds[piece], name=f"threshold_ge_{piece}")
            below_upper = tf.math.less(accumulator, self.acc_thresholds[piece + 1], name=f"threshold_lt_{piece}")
            in_piece = tf.logical_and(above_lower, below_upper, name=f"pwla_threshold_{piece}")
            slope_of_piece = piece_slopes[piece]
            offset_of_piece = self.real_offsets[piece]
            selected_slopes = tf.where(in_piece, slope_of_piece, selected_slopes)
            selected_offsets = tf.where(in_piece, offset_of_piece, selected_offsets)

        scaled = tf.multiply(accumulator, selected_slopes, "slope_mul")
        # there should be bankers rounding here (after the exponent, which is fused with the slope)
        # applying it here is more expensive than applying it after the offset, and results should be the same
        clipped = self.weight_lossy_elements.clip1(scaled, training=training)
        shifted = clipped + selected_offsets
        # NOTE: clip2 was removed because the quant applies similar clip and rounding, and should be more efficient.
        # but the banker's rounding might give slightly different results after the offset
        return tf.multiply(shifted, self.final_shift_factor, "out_shift_mul")

    def _get_mask(self, inputs):
        """
        Return a boolean mask telling, for every input value, which linear piece it falls in.

        The input is expanded with an extra axis (at position -2) and compared against
        consecutive pairs of acc_thresholds, so along that axis the mask is a one-hot
        vector of size (number of acc_thresholds - 1).

        For example, for an input of shape [n,w,h,channels] and size_a as the number of
        acc_thresholds, the function returns a boolean array of shape [n,w,h,size_a-1,channels].

        Example:
        inputs = [11350., 10296.,  1334.]  -
                shape=(3)
        acc_thresholds = array([[-32767., -32767., -32767.],
                                [  1886.,   1886.,   1886.],
                                [  2637.,   2637.,   2637.],
                                [  3608.,   3608.,   3608.],
                                [  4862.,   4862.,   4862.],
                                [  6463.,   6463.,   6463.],
                                [  8440.,   8440.,   8440.],
                                [ 10873.,  10873.,  10873.],
                                [ 32767.,  32767.,  32767.]]
                                shape - (9, 3)

         return_value= [[False, False,  True],
                        [False, False, False],
                        [False, False, False],
                        [False, False, False],
                        [False, False, False],
                        [False, False, False],
                        [False,  True, False],
                        [True, False, False]
                        shape=(8, 3)

        """
        float_inputs = tf.cast(inputs, self.FLOAT_TYPE_TF)
        inputs_expanded = tf.expand_dims(float_inputs, -2)
        thresholds = tf.cast(self.acc_thresholds, self.FLOAT_TYPE_TF)
        # Piece i covers the half-open range [lower_edges[i], upper_edges[i]).
        lower_edges = thresholds[:-1, :]
        upper_edges = thresholds[1:, :]
        at_or_above_lower = tf.math.greater_equal(inputs_expanded, lower_edges)
        below_upper = tf.math.greater(upper_edges, inputs_expanded)
        return tf.logical_and(at_or_above_lower, below_upper)

    def call_bit_exact(self, inputs, **kwargs):
        """
        Simulating APU HW - piecewise-linear approximation

        Integer flow: the per-piece mantissa, shift and offset are selected for every
        input value through the piece mask, then the value goes through
        mantissa multiplication -> rounded shift -> clip1 -> offset addition -> clip2 ->
        rounded final shift -> output lossy element.

        Args:
            inputs: sequence whose first element is the tensor to process.

        Returns:
            The output of the final lossy-element simulation stage.
        """
        inp = tf.cast(inputs[0], self.INT_TYPE_TF)

        # Integer mask that marks, per input value, the piece it belongs to (see _get_mask).
        mask = self._get_mask(inp)
        mask = tf.cast(mask, self.INT_TYPE_TF)

        # mantissas: quantize, then pick the mantissa of the selected piece by summing over the piece axis
        mantisa = tf.cast(self.weight_lossy_elements.mantissa(self.mantissas_candidate), self.INT_TYPE_TF)
        mantisa_to_hw = tf.reduce_sum(mask * mantisa, axis=-2)

        # shifts: negated sum of exponents and shift_data, selected per piece in the same way
        shift_to_hw = -1 * (tf.cast(self.exponents, self.INT_TYPE_TF) + tf.cast(self.shift_data, self.INT_TYPE_TF))
        shift_to_hw = tf.reduce_sum(mask * shift_to_hw, axis=-2)

        # offsets: already-quantized offsets, selected per piece
        real_offsets = tf.cast(self.real_offsets, self.INT_TYPE_TF)
        offsets = tf.reduce_sum(mask * real_offsets, axis=-2)

        # bankers_round 1: multiply by the mantissa, apply the rounded shift, then clip1
        mant_mul = mantisa_to_hw * inp
        signed = self.weight_lossy_elements.mantissa.signed
        bankers_round1 = self.bankers_round_with_shift(mant_mul, shift_to_hw, POST_SHIFT_1_ROUNDING, signed=signed)
        post_clip1 = self.hw_simulation_by_lossy_element(bankers_round1, self.weight_lossy_elements.clip1)

        # add the offset and apply clip2
        pre_clip2 = post_clip1 + offsets
        post_clip2 = self.hw_simulation_by_lossy_element(pre_clip2, self.weight_lossy_elements.clip2)

        # bankers_round 2: final shift, then simulate the output lossy element
        bankers_round2 = self.bankers_round_with_shift(post_clip2, self.apu_final_shift, self.apu_final_shift)
        symmetric = self.output_lossy_element.signed
        output = self.hw_simulation_by_lossy_element(bankers_round2, self.output_lossy_element, symmetric)
        return output

    def call_native(self, inputs, **kwargs):
        """
        Apply the native (floating point) activation function to the first input.

        When ``self._clip_range`` is set, the result is additionally clipped to that
        range, and a verbose log line is emitted to record that clipping took place.

        Args:
            inputs: sequence whose first element is the tensor to activate.

        Returns:
            The activation output, clipped to ``self._clip_range`` when it is not None.
        """
        result = self.act_func(inputs[0])
        # The _clip_range code is related to INTERNAL/EXPERIMENTAL code in clip_activation_stats
        if self._clip_range is not None:
            # Log only when clipping is actually applied, so the message is not misleading.
            self._logger.verbose(f"{self.full_name}: clipped native")
            result = tf.clip_by_value(result, *self._clip_range)
        return result

    def get_input_scale_candidate(self):
        """
        Build an input scale candidate from the minimal input scale.

        The ratio between the minimal input scale and the output scale is reshaped to
        ``(-1, quantization_groups_num)``, reduced with a max over the first axis,
        repeated ``base_group_size`` times per entry and multiplied back by the output scale.
        """
        scale_ratio = self.calc_min_input_scale() / self.output_scale
        ratio_grid = np.reshape(scale_ratio, (-1, self.quantization_groups_num))
        ratio_per_group = np.max(ratio_grid, axis=0)
        ratio_per_channel = np.repeat(ratio_per_group, self.base_group_size)
        return self.output_scale * ratio_per_channel

    def create_act_name_and_func(self, activation: Union[ActivationType, str, callable]):
        """
        Resolve an activation specification into its name and native implementation.

        Args:
            activation: if str - activation type as listed in ActivationType enum
                     if ActivationType - looked up in activation_gen_by_type
                     if callable - used as the activation function as-is

        Returns:
            For a str / ActivationType: the tuple (act_name, act_func). Both values are also
            stored on self.act_name and self.act_func.
            For a callable: the tuple (activation, "callable"). Nothing is stored on self.

        NOTE(review): the two return paths order the tuple differently (function first for a
        callable, name first otherwise) - confirm this is what the callers expect.

        Several of the native implementations defined here read their parameters from
        self.act_native_params when they are called (not when this method runs).

        # TODO: should be triggered earlier, but located here because this class supports callable activation
        # TODO Hailo PL representation
        # TODO add support for functools.partial activation functions

        """
        if callable(activation):
            return activation, "callable"

        # presumably raises ValueError for a string that is not an ActivationType value - verify
        if isinstance(activation, str):
            activation = ActivationType(activation)

        def biased_delta(x, **kwargs):
            """
            Realization of biased_delta activation
            """
            # With eps == 0 this is activation_delta_bias where x != 0, and 0 where x == 0.
            # TODO: Set a higher eps if we find that we need to increase the tolerance of the operation.
            eps = 0.0
            return self.act_native_params["activation_delta_bias"] * tf.sign(tf.maximum(tf.abs(x) - eps, 0.0))

        def delta(x, **kwargs):
            # With eps == 0 this is 1 where x == 0, and 0 elsewhere.
            # TODO: Set a higher eps if we find that we need to increase the tolerance of the operation.
            eps = 0.0
            return 1 - tf.sign(tf.maximum(tf.abs(x) - eps, 0.0))

        def inv_pos(x, **kwargs):
            # Reciprocal of the input after scaling it by inverse_act_factor.
            x = x * self.act_native_params["inverse_act_factor"]
            return hailo_reciprocal(x)

        def minus_inv_pos(x, **kwargs):
            # Reciprocal of the negated input.
            return hailo_reciprocal(-x)

        def relu1(x, **kwargs):
            """
            Realization of relu1 activation
            """
            # relu6 on the 6x-scaled input, scaled back: clamps x to [0, 1].
            return tf.nn.relu6(x * 6.0) / 6.0

        def less(x, **kwargs):
            # Elementwise (x < y) as a float tensor, y taken from act_native_params.
            return tf.cast(tf.math.less(x, self.act_native_params["y"]), self.FLOAT_TYPE_TF)

        def greater(x, **kwargs):
            # Elementwise (x > y) as a float tensor, y taken from act_native_params.
            return tf.cast(tf.math.greater(x, self.act_native_params["y"]), self.FLOAT_TYPE_TF)

        def native_hardswish_activation(x, **kwargs):
            return x * tf.nn.relu6(x + 3) / 6

        def native_swish_activation(x, **kwargs):
            return x * tf.nn.sigmoid(self.act_native_params["swish_beta"] * x)

        def native_hardsigmoid_activation(x, **kwargs):
            # alpha * x + beta, clamped to [0, 1].
            return tf.math.maximum(
                tf.constant([0.0]),
                tf.math.minimum(
                    self.act_native_params["hardsigmoid_alpha"] * x + self.act_native_params["hardsigmoid_beta"],
                    tf.constant([1.0]),
                ),
            )

        def native_clip_activation(x, **kwargs):
            # Clamp x to [clip_min, clip_max]; the bounds are cast to float32.
            return tf.math.maximum(
                tf.cast(self.act_native_params["clip_min"], dtype=tf.float32),
                tf.math.minimum(x, tf.cast(self.act_native_params["clip_max"], dtype=tf.float32)),
            )

        def leaky_relu(x, **kwargs):
            # A new Keras layer is built on every call, with act_native_params as its kwargs.
            leaky_op = tf.keras.layers.LeakyReLU(**self.act_native_params)
            return leaky_op(x)

        def inv_sqrt(x, **kwargs):
            x = x * self.act_native_params["inverse_act_factor"]
            x = tf.cast(x, self.FLOAT_TYPE_TF)
            # Values below eps are raised to eps before the square root is taken.
            eps = 1e-6
            return hailo_reciprocal(tf.math.sqrt(tf.maximum(x, eps)))

        def prelu(x, **kwargs):
            # Positive part is passed through; negative part is scaled by prelu_slope.
            pos = tf.nn.relu(x)
            neg = -self.act_native_params["prelu_slope"] * tf.nn.relu(-x)
            return pos + neg

        def threshold(x, **kwargs):
            # A new Keras layer is built on every call, with act_native_params as its kwargs.
            op = tf.keras.layers.ThresholdedReLU(**self.act_native_params)
            return op(x)

        def mish(x, **kwargs):
            x = tf.cast(x, self.FLOAT_TYPE_TF)
            x = x * tf.math.tanh(tf.math.softplus(x))
            return x

        # NOTE: this local name shadows the builtin pow inside this method.
        def pow(x, **kwargs):
            pow_exponent = self.act_native_params["pow_exponent"]
            x = tf.pow(x, pow_exponent)
            return x

        def hdr_compression(x):
            """
            This function compresses the input of 20-bit to 16-bit.
            Decompression from 16-bit to 20-bit is applied by the ISP of H-15.
            """
            # Four input ranges, each with its own divisor and additive constant.
            y = tf.zeros_like(x)
            y = tf.where(x < 2**14, x / 2, y)
            y = tf.where(tf.logical_and(x >= 2**14, x < 2**15), x / 4 + 4096, y)
            y = tf.where(tf.logical_and(x >= 2**15, x < 2**18), x / 8 + 8192, y)
            y = tf.where(x >= 2**18, x / 32 + 32768, y)
            return y

        def relu_positive_square(x):
            return tf.math.square(tf.nn.relu(x))

        def pwl(x, thresholds=None, offsets=None, slopes=None):
            # Generic piecewise-linear function. Any argument left as None is read from
            # act_native_params; explicitly given ones are cast to the dtype of x.
            thresholds = (
                np.array(self.act_native_params["thresholds"]) if thresholds is None else tf.cast(thresholds, x.dtype)
            )
            offsets = np.array(self.act_native_params["offsets"]) if offsets is None else tf.cast(offsets, x.dtype)
            slopes = np.array(self.act_native_params["slopes"]) if slopes is None else tf.cast(slopes, x.dtype)
            inputs_expanded = tf.expand_dims(x, -1)
            # Mask over the pieces along the last axis: below the first threshold,
            # between consecutive thresholds, and at/above the last threshold.
            mask = tf.concat(
                [
                    tf.math.less(inputs_expanded, thresholds[0]),
                    tf.logical_and(
                        tf.math.greater_equal(inputs_expanded, thresholds[:-1]),
                        tf.math.less(inputs_expanded, thresholds[1:]),
                    ),
                    tf.math.greater_equal(inputs_expanded, thresholds[-1]),
                ],
                axis=-1,
            )
            mask = tf.cast(mask, x.dtype)
            # Pick the slope and offset of the selected piece for every input value.
            slopes = tf.reduce_sum(mask * slopes, axis=-1)
            offsets = tf.reduce_sum(mask * offsets, axis=-1)
            return x * slopes + offsets

        def exp_decompose(x):
            # Without a mask: 6 * 2^floor(log2(|x|)).
            # With a mask: a zero-slope pwl whose thresholds and offsets are derived from the mask.
            mask = self.act_native_params.get("mask", None)
            if mask is None:
                return 6.0 * (2.0 ** tf.floor(tf.math.log(tf.abs(x)) / tf.math.log(2.0)))
            else:
                # NOTE(review): mask[::-1] and np.maximum assume mask is array-like - confirm
                larger_mask = np.maximum(mask, -mask[::-1])[len(mask) // 2 :]
                threshold = mask[1:-1]
                offsets = 3 * tf.concat([larger_mask[::-1], larger_mask[1:]], axis=-1)
                slopes = 0.0
                return pwl(x, thresholds=threshold, offsets=offsets, slopes=slopes)

        def shift(x):
            # Without a mask: 2 * x / 2^floor(log2(|x|)) - 3.
            # With a mask: a pwl with constant offset -3 and slopes 1 / larger_mask.
            mask = self.act_native_params.get("mask", None)
            if mask is None:
                return 2.0 * x / (2.0 ** tf.floor(tf.math.log(tf.abs(x)) / tf.math.log(2.0))) - 3.0
            else:
                larger_mask = np.maximum(mask, -mask[::-1])[len(mask) // 2 :]
                # Thresholds are moved down by a small margin relative to the largest mask value.
                buffer = larger_mask[-1] / 2**15
                threshold = 4 * larger_mask[:-1] - buffer
                offsets = -3.0
                slopes = 1.0 / larger_mask
                return pwl(x, thresholds=threshold, offsets=offsets, slopes=slopes)

        # Lookup table from activation type to its native implementation.
        activation_gen_by_type = {
            ActivationType.RELU: tf.nn.relu,
            ActivationType.RELU6: tf.nn.relu6,
            ActivationType.RELU1: relu1,
            ActivationType.SIGMOID: tf.nn.sigmoid,
            ActivationType.LINEAR: tf.identity,
            ActivationType.ELU: tf.nn.elu,
            ActivationType.EXP: tf.keras.activations.exponential,
            ActivationType.TANH: tf.keras.activations.tanh,
            ActivationType.SOFTPLUS: tf.keras.activations.softplus,
            ActivationType.SILU: tf.nn.silu,
            ActivationType.GELU: tf.nn.gelu,
            ActivationType.MISH: mish,
            ActivationType.INV_POS: inv_pos,
            ActivationType.MINUS_INV_POS: minus_inv_pos,
            ActivationType.BIASED_DELTA: biased_delta,
            ActivationType.SQRT: tf.sqrt,
            ActivationType.HARDSWISH: native_hardswish_activation,
            ActivationType.SWISH: native_swish_activation,
            ActivationType.LESS: less,
            ActivationType.LOG: tf.math.log,
            ActivationType.HARDSIGMOID: native_hardsigmoid_activation,
            ActivationType.CLIP: native_clip_activation,
            ActivationType.INV_SQRT: inv_sqrt,
            ActivationType.LEAKY: leaky_relu,
            ActivationType.PRELU: prelu,
            ActivationType.THRESHOLD: threshold,
            ActivationType.SOFTSIGN: tf.keras.activations.softsign,
            ActivationType.DELTA: delta,
            ActivationType.GREATER: greater,
            ActivationType.POW: pow,
            ActivationType.HDR_COMPRESSION: hdr_compression,
            ActivationType.RELU_POSITIVE_SQUARE: relu_positive_square,
            ActivationType.PWL: pwl,
            ActivationType.EXP_DECOMPOSE: exp_decompose,
            ActivationType.SHIFT: shift,
        }

        # A type that is missing from the table raises KeyError on the lookup below.
        self.act_name = activation
        self.act_func = activation_gen_by_type[activation]
        return self.act_name, self.act_func

    def export_independent_params(self):
        """
        Export the independent parameters of the op as a dictionary.

        The output factors and the piecewise description (thresholds, offsets, slopes)
        are exported as float32 numpy arrays; the input shape is exported without its
        first dimension.
        """

        def as_float32(values):
            return np.array(values, np.float32)

        exported = {"output_factors": as_float32(self.output_factor_by_group)}
        # TODO: find more suitable place for these
        exported["thresholds"] = as_float32(self.thresholds)
        exported["offsets"] = as_float32(self.offsets)
        exported["slopes"] = as_float32(self.slopes)
        exported["input_shape"] = self.input_shape[1:]
        return exported

    def import_independent_params(self, params):
        """
        Load the independent parameters produced by export_independent_params.

        The values are assigned to the op as given, the input shape is restored with a
        leading ``None`` dimension, and the mantissa / exponent decomposition is updated.
        """
        attribute_to_key = (
            ("output_factor_by_group", "output_factors"),
            ("thresholds", "thresholds"),
            ("offsets", "offsets"),
            ("slopes", "slopes"),
        )
        for attribute, key in attribute_to_key:
            setattr(self, attribute, params[key])

        restored_shape = [None, *params["input_shape"]]
        self._input_shapes = [restored_shape]
        self.update_mantissa_exponent_decomposition()

    def export_quant_weights(self):
        """
        Export the quantized piecewise parameters, split to quantization groups.

        The exponent range is checked first. Thresholds are rounded and clipped to the
        range of the thresholds lossy element before being exported.
        """
        self.check_exp_range()

        def by_group(values):
            return split_to_quantize_groups(values, self.base_group_size, offset=0)

        grouped_thresholds = by_group(self.acc_thresholds)
        grouped_offsets = by_group(self.real_offsets)
        grouped_exponents = by_group(self.exponents)

        # This is under the assumption that enable_lossy was called.
        grouped_mantissas = by_group(self.mantissas_candidate)

        assigned_exponents = self.get_assigned_exponent(-grouped_exponents)
        max_offset_bias = 2 ** (self.weight_lossy_elements.offset.bits - 1)
        group_sizes = get_split_size(self.quantization_groups_num, self.base_group_size, self.num_of_channels)

        threshold_element = self.weight_lossy_elements.thresholds
        grouped_thresholds = np.clip(
            np.round(grouped_thresholds),
            threshold_element.min_value,
            threshold_element.max_value,
        )

        # we dont add to the x_points a point because it comes with (-inf, +inf)
        exported = {
            "quant_thresholds": grouped_thresholds,
            "quant_offsets": np.round(np.float32(grouped_offsets)),
            "slopes_mantissas": np.float32(grouped_mantissas),
            "slopes_exponents": np.round(np.float32(assigned_exponents)),
            "exponent_shift_bias": np.array(self.apu_exp_bias, np.float32),
            "output_shift_bias": np.array(self.apu_final_shift, np.float32),
            "max_offset_bias": np.array(max_offset_bias, np.float32),
            "x_points_mask_max_value": np.float32(DEFAULT_X_POINTS_MAX_VALUE),
            "quantization_groups_size": np.array(group_sizes),
        }
        return exported

    def export_hw_params(self):
        """
        Export the piecewise parameters in the layout and integer types used for the HW.

        For RELU6 with a small enough output scale, the last piece is removed. The
        offsets, mantissas and exponents are then passed through
        _add_first_and_end_points before being cast.
        """
        params = self.export_quant_weights()
        offsets_q = params["quant_offsets"]
        mantissas_q = params["slopes_mantissas"]
        exponents_q = params["slopes_exponents"]
        thresholds_q = params["quant_thresholds"]

        mantissas_q = self.weight_lossy_elements.mantissa(mantissas_q)

        if self.act_name == ActivationType.RELU6:
            # 6.01 is for overcoming numeric instability
            scale_th = 6.01 / (2**self.output_lossy_element.bits - 1)
            if self.output_scale[0] <= scale_th:
                # remove last piece
                offsets_q, mantissas_q, exponents_q = (
                    values[:, :-1] for values in (offsets_q, mantissas_q, exponents_q)
                )
                # drop the second-to-last threshold, keep the last one
                thresholds_q = np.concatenate([thresholds_q[:, :-2], thresholds_q[:, -1:]], axis=-1)

        offsets_q = _add_first_and_end_points(offsets_q)
        mantissas_q = _add_first_and_end_points(mantissas_q)
        exponents_q = _add_first_and_end_points(exponents_q)

        hw_params = {
            "output_stage/piecewise/x_points": thresholds_q.astype(np.int32),
            "output_stage/piecewise/offsets": offsets_q.astype(np.int32),
            "output_stage/piecewise/slopes_m": mantissas_q.astype(np.int16),
            "output_stage/piecewise/slopes_e": exponents_q.astype(np.uint8),
            "output_stage/piecewise/size_splits": params["quantization_groups_size"].astype(np.uint32),
        }
        return hw_params

    def _apply_activation_fitting(self, optimization_target):
        """
        Fit a piecewise-linear approximation to the activation and store it on the op.

        The inverse-like activations (INV_POS, MINUS_INV_POS, INV_SQRT) are fitted with
        3500 samples, all the others with 350. The fitted thresholds, slopes and offsets
        are written to ``self.thresholds``, ``self.slopes`` and ``self.offsets``.
        """
        self._logger.info(f"activation fitting started for {self.full_name}")
        densely_sampled = [ActivationType.INV_POS, ActivationType.MINUS_INV_POS, ActivationType.INV_SQRT]
        num_of_samples = 3500 if self.act_name in densely_sampled else 350
        fitted = self.find_optimized_breaks(
            self.act_func,
            self.get_input_limvals(0),
            optimization_target,
            num_of_samples=num_of_samples,
        )
        self.thresholds, self.slopes, self.offsets = fitted

    def _get_x_y(self, func, xrange_limvals, num_of_samples, histogram):
        """
        Draw sample points that follow the input histogram, and evaluate ``func`` on them.

        Evenly spaced quantiles in [0.001, 0.999] are mapped through the cumulative
        histogram to bin midpoints, and a uniform jitter of up to half a bin width is
        added (fixed seed, so the result is reproducible).

        Returns:
            x: the sorted sample points.
            y: ``func(x)`` cast to float64.
        """
        lower, upper = xrange_limvals
        edges = np.linspace(lower, upper, len(histogram) + 1, dtype=np.float64)
        bin_widths = np.diff(edges)
        bin_midpoints = edges[:-1] + bin_widths / 2

        # Normalized cumulative histogram.
        cumulative = np.cumsum(histogram)
        cumulative = cumulative / cumulative[-1]

        quantiles = np.linspace(0.001, 0.999, num_of_samples, dtype=np.float64)
        chosen_bins = np.searchsorted(cumulative, quantiles)

        half_width = bin_widths[0] / 2
        rng = np.random.default_rng(89)
        jitter = rng.uniform(-half_width, half_width, size=num_of_samples)
        x = np.sort(bin_midpoints[chosen_bins] + jitter)

        y = tf.cast(func(x), dtype=np.float64)
        return x, y

    def _get_min_max(self, func, xrange_limvals):
        x_min, x_max = xrange_limvals
        bits = self.output_lossy_element.bits
        max_y = func(x_max)
        halph_bin = max_y / 2 ** (bits + 1)
        x_min = np.max([-(1 / halph_bin**2), x_min])
        return x_min, x_max

    def find_optimized_breaks(self, func, xrange_limvals, optimization_target, n_segments=8, num_of_samples=350):
        """
        Fit a piecewise-linear approximation of ``func`` over the given input range.

        Args:
            func: the activation function to approximate.
            xrange_limvals: (x_min, x_max) of the input range.
            optimization_target: forwarded to _get_slopes.
            n_segments: number of linear segments requested from the fit.
            num_of_samples: number of sample points used for the fit.

        Returns:
            The inner break points (without the first and the last one), the slopes and
            the offsets of the fitted pieces.

        Raises:
            AccelerasInitializationError: if histogram-based sampling is needed but no
                input histogram exists, or if negative slopes remain after the fit.
        """
        x_min, x_max = xrange_limvals
        # Activation specific adjustments of the fitted range.
        if self.act_name == ActivationType.EXP:
            self._logger.debug(f"the limvals before ({x_min},{x_max})")
            bits = self.output_lossy_element.bits
            # smallest_value is negative, so x_min is raised to at least x_max + smallest_value
            smallest_value = np.log(1 / 2 ** (bits + 1))  # x_max - 6.23
            x_min = np.maximum(x_max + smallest_value, x_min)
        elif self.act_name == ActivationType.LOG:
            # a non-positive lower limit is replaced by a small positive one
            x_min = 1e-3 if x_min <= 0 else x_min
        elif self.act_name == ActivationType.INV_SQRT and self.act_native_params["inverse_act_factor"] < 0:
            x_min, x_max = self._get_min_max(func, xrange_limvals)

        # For these activations the sample points follow the input histogram;
        # for all the others they are spread uniformly over the range.
        if (
            self.act_name in [ActivationType.INV_POS, ActivationType.INV_SQRT]
            and self.act_native_params["inverse_act_factor"] < 0
        ) or self.act_name == ActivationType.MINUS_INV_POS:
            histogram = self.get_input_stats(0).dynamic_histogram
            if histogram is None:
                raise AccelerasInitializationError(
                    f"Activation fitting was enabled but there is no histogram on {self.act_name} in layer {self.full_name}",
                )
            x, y = self._get_x_y(func, (x_min, x_max), num_of_samples, histogram=histogram)
            self._logger.debug(f"the limvals before ({x_min},{x_max})")
            # the range is narrowed to the span of the drawn samples
            x_min = np.min(x)
            x_max = np.max(x)
            self._logger.debug(f"the limvals after ({x_min},{x_max})")
        else:
            x = np.linspace(x_min, x_max, num_of_samples)
            y = func(x)

        x_points, my_pwlf, x, y = self._get_slopes(x, y, n_segments, func, optimization_target)

        slopes = my_pwlf.calc_slopes()
        offsets = my_pwlf.intercepts
        y_points = my_pwlf.predict(x_points)

        # an edge case where the slopes are negative and the algorithm before didn't fix them
        if (slopes < 0).any():
            x_points, slopes, offsets, y_points = self._change_slopes(x_points, slopes, offsets, my_pwlf)

        # y_points_real is used only for the debug report of the fit error below
        y_points_real = func(x_points)
        self._logger.debug(f"the limvals are ({x_min},{x_max}) .")
        self._logger.debug(f"x = {x_points}")
        self._logger.debug(f"y = {y_points}")
        self._logger.debug(f"slopes = {slopes}")
        self._logger.debug(f"error y_points predicted and y_points_real = {y_points_real - y_points}")

        # When the HW limitation assertion is ignored, negative slopes are zeroed instead of raising.
        if self._ignore_hw_limitation_assertion == IgnoreHwLimitationAssertionPolicy.enabled:
            slopes = np.maximum(slopes, 0)
        if (np.array(slopes) < 0).any():
            raise AccelerasInitializationError(f"negative slopes are needed in {self.full_name} but not supported")

        # the first and the last break points are dropped from the returned thresholds
        return x_points[1:-1], slopes, offsets

    def _get_slopes(self, x, y, n_segments, func, optimization_target):
        """
        Fit a piecewise-linear function to (x, y), retrying up to three times.

        The fitting stops early when either all the slopes are positive and the
        optimization target is SAGE, or the HW limitation assertion is ignored.
        Otherwise, a sample point is added in the middle of every negative-slope piece,
        ``y`` is re-evaluated with ``func`` and the fit is repeated.

        Returns:
            The break points and the fit object of the last attempt, along with the
            (possibly extended) x and y samples.
        """
        # optimize the piecewise linear function
        start_time = time.time()
        ignore = self._ignore_hw_limitation_assertion == IgnoreHwLimitationAssertionPolicy.enabled

        # Fit the piecewise linear function in a loop
        for j in range(3):
            my_pwlf = pwlf.PiecewiseLinFit(x, y)
            x_points = my_pwlf.fit(n_segments=n_segments, seed=1)

            # Calculate the slopes
            slopes = my_pwlf.calc_slopes()

            # NOTE: `ignore` alone is enough to stop; the positivity check applies only to SAGE.
            positive_for_sage = np.all(slopes > 0) and optimization_target == OptimizationTarget.SAGE
            finished = positive_for_sage or ignore
            if not finished:
                # in case the slopes are negative we will add a point in the middle of the
                # negative slopes piece and fit again
                for i in np.where(slopes < 0):
                    segment_length = x_points[i + 1] - x_points[i]
                    x = np.insert(x, 0, x_points[i] + segment_length / 2)
                x = np.sort(x)
                y = tf.cast(func(x), dtype=np.float64)
            self._logger.debug(f"{j}, {slopes}")
            if finished:
                break

        elapsed_time = time.time() - start_time

        self._logger.verbose(f"activation fitting ended in {elapsed_time:.2f} seconds type {self.act_name}")
        return x_points, my_pwlf, x, y

    def get_encoding_flow(self):
        """
        Return the encoding flow of the op.

        With more than one quantization group the encodings are forced to be constant
        (a warning is logged the first time), and then the base implementation is used.
        """
        # TODO: when implmenting quantization groups- we need to remove this function. SDK-42505
        if self.quantization_groups_num > 1:
            if not self.encoding_const:
                self._logger.warning(
                    f"get_encoding_flow for activation op {self.full_name} with multiple groups is't supported yet.",
                )
                self.encoding_const = True
        return super().get_encoding_flow()

    def define_encodings(self, flow):
        """
        Register the encodings of the op on the flow, on top of the base class encodings.

        Adds the output factors (raw and clipped), the per-group slope decomposition
        (exponent factors, mantissa candidates, exponents) and the piece encodings
        (offset candidates, quantized offsets, quantized thresholds). For a
        non-homogeneous op, the output scale encoding is marked as scalar.
        """
        super().define_encodings(flow)

        def declare(suffix, encoding_type, **options):
            flow.add_encoding(f"{self.full_name}/{suffix}:0", encoding_type, **options)

        # output factors
        declare(
            "output_factor_by_group",
            EncodingType.Scale,
            scalar=True,
            shape=self.output_factor_by_group.shape,
            initializer=TensorInitializer(self.output_factor_by_group),
        )
        declare(
            "output_factor_by_group_clipped",
            EncodingType.Scale,
            scalar=True,
            shape=self.output_factor_by_group.shape,
        )

        # slopes decomposition, per group
        declare("exponent_factors", EncodingType.Scale, scalar=False, shape=self.exponent_factors_by_group.shape)
        declare("mantissas_candidate", EncodingType.Scale, scalar=False, shape=self.mantissas_candidate_by_group.shape)
        declare("exponents", EncodingType.Scale, scalar=False, shape=self.exponents_by_group.shape)

        # offsets and thresholds of the pieces
        declare("offset_candidates", EncodingType.ZeroPoint, scalar=False, shape=self.offset_candidates.shape)
        declare("quant_offsets", EncodingType.ZeroPoint, scalar=False, shape=self.real_offsets.shape)
        declare("quant_thresholds", EncodingType.ZeroPoint, scalar=False, shape=self.acc_thresholds.shape)

        if self.homogeneous:
            return
        flow.nodes[f"{self.full_name}/output_scale:0"]["encoding"].scalar = True

    def define_constraints(self, enc):
        """
        Define the constraints between the encodings of the op, on top of the base class ones.

        The chain is: output factors are clipped -> the output scale is tied to the
        input scale through the clipped factors -> the slopes are decomposed to
        mantissas and exponents -> the quantized total slopes are built -> offsets and
        thresholds are derived from the total slopes and the input / output encodings.
        """
        super().define_constraints(enc)

        # TODO: replace slops_dtype this with self.FLOAT_TYPE_NP after
        # https://bitbucket.org/hailotech/phase2-sdk/pull-requests/21700 is merged
        slops_dtype = self.FLOAT_TYPE_NP

        # clip negative exponents
        # The upper clip value is derived from the largest non-zero slope; when all the
        # slopes are zero, the largest finite value of the dtype is used instead.
        non_zero_slopes = np.array(self.slopes, dtype=slops_dtype)[np.nonzero(self.slopes)]
        if non_zero_slopes.size == 0:
            clip_value_max = np.finfo(slops_dtype).max
        else:
            clip_value_max = (
                self.final_shift_factor
                * (2 ** (APU_MANTISSA_BITS + self.shift_data - self.apu_exp_bias))
                / np.max(non_zero_slopes)
            )
        enc.callback(
            f"{self.full_name}/output_factor_by_group_clipped:0",
            f"{self.full_name}/output_factor_by_group:0",
            tf.clip_by_value,
            callback_name="tf.clip_by_value",
            clip_value_min=0.0,
            clip_value_max=clip_value_max,
        )

        # compute output_scale
        # the per-group clipped factors are first expanded through _group_to_vector
        enc.callback(
            enc.dummy("vector_output_factors"),
            f"{self.full_name}/output_factor_by_group_clipped:0",
            self._group_to_vector,
            callback_name="tf.repeat",
            outs_shape=self.output_factors.shape,
        )
        enc.mul(
            f"{self.full_name}/input_scale:0",
            f"{self.full_name}/output_scale:0",
            enc.dummy("vector_output_factors"),
            inverse=True,
        )

        # compute slopes encodings
        # the native slopes are bound as the first argument of the decomposition callback
        enc.callback(
            [
                f"{self.full_name}/exponent_factors:0",
                f"{self.full_name}/mantissas_candidate:0",
                f"{self.full_name}/exponents:0",
            ],
            f"{self.full_name}/output_factor_by_group_clipped:0",
            partial(self._get_mantissa_exponent_decomposition, np.array(self.slopes, dtype=slops_dtype)),
            callback_name="decompose_slopes",
            outs_scalar=[False, False, False],
        )

        # expand the per-group exponent factors and mantissa candidates
        enc.callback(
            enc.dummy("vector_exponent_factors"),
            f"{self.full_name}/exponent_factors:0",
            self._group_to_vector,
            callback_name="tf.repeat",
            outs_shape=self.exponent_factors.shape,
        )
        enc.callback(
            enc.dummy("vector_mantissas_candidate"),
            f"{self.full_name}/mantissas_candidate:0",
            self._group_to_vector,
            callback_name="tf.repeat",
            outs_shape=self.mantissas_candidate.shape,
        )

        # total slopes = mantissas (after the mantissa lossy element) * exponent factors
        enc.lossy_element(
            enc.dummy("mantissas"),
            enc.dummy("vector_mantissas_candidate"),
            self.weight_lossy_elements.mantissa,
        )
        enc.mul(enc.dummy("total_slopes"), enc.dummy("mantissas"), enc.dummy("vector_exponent_factors"))

        # compute offsets and thresholds encodings
        enc.callback(
            [
                f"{self.full_name}/offset_candidates:0",
                f"{self.full_name}/quant_offsets:0",
                f"{self.full_name}/quant_thresholds:0",
            ],
            [
                enc.dummy("total_slopes"),
                f"{self.full_name}/input_scale:0",
                f"{self.full_name}/input_zero_point:0",
                f"{self.full_name}/output_scale:0",
                f"{self.full_name}/output_zero_point:0",
            ],
            self._get_pieces_encoding,
            callback_name="get_pieces_encoding",
            outs_scalar=[False, False, False],
        )

    def define_const_constraints(self, enc):
        """
        Pin every encoding of the op to its current value, on top of the base class ones.

        Both the raw and the clipped output factor encodings are pinned to
        ``self.output_factor_by_group``.
        """
        super().define_const_constraints(enc)

        # (encoding name, attribute that holds its constant value)
        pinned_encodings = (
            ("output_factor_by_group", "output_factor_by_group"),
            ("output_factor_by_group_clipped", "output_factor_by_group"),
            ("exponent_factors", "exponent_factors_by_group"),
            ("mantissas_candidate", "mantissas_candidate_by_group"),
            ("exponents", "exponents_by_group"),
            ("offset_candidates", "offset_candidates"),
            ("quant_offsets", "real_offsets"),
            ("quant_thresholds", "acc_thresholds"),
        )
        for encoding_name, attribute in pinned_encodings:
            enc.identity(f"{self.full_name}/{encoding_name}:0", getattr(self, attribute))

    def update_encoding(self, encodings):
        """
        Copy the solved encodings back to the op, on top of the base class update.

        Note that ``output_factor_by_group`` is read from the clipped output factor encoding.
        """
        super().update_encoding(encodings)

        # (attribute to set, encoding name to read)
        solved_encodings = (
            ("output_factor_by_group", "output_factor_by_group_clipped"),
            ("exponent_factors_by_group", "exponent_factors"),
            ("mantissas_candidate_by_group", "mantissas_candidate"),
            ("exponents_by_group", "exponents"),
            ("offset_candidates", "offset_candidates"),
            ("real_offsets", "quant_offsets"),
            ("acc_thresholds", "quant_thresholds"),
        )
        for attribute, encoding_name in solved_encodings:
            setattr(self, attribute, encodings[f"{self.full_name}/{encoding_name}:0"])

    def _change_slopes(self, x_points, slopes, offsets, my_pwlf):
        y_points = my_pwlf.predict(x_points)

        x_points_new = [x_points[0]]
        slopes_new = []
        offsets_new = []
        i = 0
        while i < len(slopes):
            m = slopes[i]
            if m < 0:
                start_seg = y_points[i]
                j = i + 1
                while j < len(slopes) and y_points[j + 1] < start_seg:
                    j += 1
                if j == len(slopes):
                    x_points_new.append((y_points[i] - offsets[j - 1]) / slopes[j - 1])
                else:
                    x_points_new.append((y_points[i] - offsets[j]) / slopes[j])
                offsets_new.append(y_points[i])
                slopes_new.append(0)
                i = j
            else:
                x_points_new.append(x_points[i + 1])
                offsets_new.append(offsets[i])
                slopes_new.append(m)
                i += 1
        x_points = np.array(x_points_new)
        slopes = np.array(slopes_new)
        offsets = np.array(offsets_new)
        y_points = my_pwlf.predict(x_points)
        return x_points, slopes, offsets, y_points

    def get_output_limvals(self, output_index: int):
        """
        Return the limit values of the requested output.

        When statistics were registered under ``outputs_<output_index>`` in
        ``self.stats_managers`` the parent implementation is used. Otherwise the limits are
        derived by applying ``self.act_func`` to the limit values of input 0.

        Args:
            output_index: index of the output whose limits are requested.

        Returns:
            The output limit values; a numpy array on the fallback path.

        """
        if f"outputs_{output_index}" not in self.stats_managers:
            # No collected output stats: push the input limits through the activation.
            input_limits = self.get_input_limvals(0)
            return np.array(self.act_func(input_limits))
        return super().get_output_limvals(output_index)

    def harmless_clipping(
        self,
        bins_clip: float = 0.25,
        samples_per_bin=32,
    ) -> Tuple[float, float]:
        """
        Find input values between which clipping should not affect the quantized output.

        The input limit range is sampled uniformly, the activation is evaluated on the samples,
        and the samples whose output lies within ``bins_clip`` output bins of the minimal /
        maximal output are located. Inputs beyond the returned points are expected to map to the
        lowest and highest output bins anyway, so these clips are meant for bounded activations.

        Args:
            bins_clip (float, optional): fraction of an output bin (in units of the output scale)
                used as the margin around the minimal and maximal outputs. Defaults to 0.25.
            samples_per_bin (int, optional): number of samples per input bin; the total sample
                count is ``2**input_bits * samples_per_bin``. Defaults to 32.

        Returns:
            The lower and upper input values that can be clipped to without changing the lossy
            output.

        """
        # Only the first scale entry is used; this might be problematic with vector scales.
        scale = self.output_scale[0]
        low_lim, high_lim = self.get_input_limvals(0)
        n_samples = 2**self.input_lossy_element.bits * samples_per_bin
        grid = np.linspace(low_lim, high_lim, n_samples)
        outputs = self.act_func(grid)

        margin = scale * bins_clip
        out_values = outputs.numpy()
        # Last sample still within the margin above the minimum, first one within the margin below the maximum.
        last_low = np.where(out_values < np.min(outputs) + margin)[0][-1]
        first_high = np.where(out_values > np.max(outputs) - margin)[0][0]
        return grid[last_low], grid[first_high]

    def _crop_offsets(self):
        out_min, out_max = self.get_output_limvals(0)
        offsets = np.array(self.offsets)[np.where(self.offsets < out_max * self.REMOVE_OFFSET_FACTOR)]
        remove = len(np.array(self.offsets)) - len(offsets)
        self.slopes = np.array(self.slopes)[:-remove]
        self.thresholds = np.array(self.thresholds)[:-remove]
        self.offsets = offsets

    def import_flow_state(self, atomic_state: AtomicOpState):
        """
        Import the flow state of this op.

        Delegates to the parent implementation and then, only when the state's
        ``aops_dict_kwgs`` contains a ``"native_act"`` entry, updates ``fully_native`` from it.

        Args:
            atomic_state: state object to import from.

        """
        super().import_flow_state(atomic_state)
        if "native_act" not in atomic_state.aops_dict_kwgs:
            return
        self._set_fully_native_from_flow_state(atomic_state)

    def _set_fully_native_from_flow_state(self, atomic_state):
        native_action = atomic_state.aops_dict_kwgs.get("native_act")

        if native_action == "enable":
            self.fully_native = True
        elif native_action == "disable":
            self.fully_native = False
        else:
            raise ValueError(f"Unknown native_act value: {native_action}")
