#!/usr/bin/env python

import time

import numpy as np
import pwlf
from past.utils import old_div

from hailo_model_optimization.acceleras.utils.acceleras_definitions import ActivationFitPolicy
from hailo_sdk_client.sdk_backend.sdk_backend_exceptions import (
    BackendNotImplementedError,
    BackendQuantizationException,
    SDKBackendException,
)
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType
from hailo_sdk_common.logger.logger import default_logger

# NOTE(review): ACT_COMP_MAX_NO and ACT_COMP_MAX_VALUE are not referenced in the visible part of
# this file; presumably they are limits consumed by other modules - confirm before changing.
ACT_COMP_MAX_NO = 8
ACT_COMP_MAX_VALUE = 0x7FFF
# Upper bounds checked by PiecewiseActivator.get_piecewise_activator_scaled_params when more than
# one quantization group is requested and the architecture is not flagged as mercury.
MAX_ALLOWED_QUANTIZATION_GROUPS = 4
MAX_ALLOWED_PIECE_NUMBER = 2


class BackendNegativeSlopesException(SDKBackendException):
    """Error carrying the magnitude of the most negative slope exponent.

    Raised by ``PiecewiseActivator._reduce_slope_precision`` when a slope with a non-zero
    mantissa ends up with a negative exponent.
    """

    def __init__(self, message, abs_min_slope):
        super().__init__(message)
        # Kept private; exposed read-only through the property below.
        self._abs_min_slope = abs_min_slope

    @property
    def abs_min_slope(self):
        """The ``abs_min_slope`` value given at construction time."""
        return self._abs_min_slope


class BackendOffsetsException(SDKBackendException):
    """Error raised when a quantized offset does not fit the allowed range.

    Stores the reported bit loss together with the matching scale factor ``2 ** bit_loss``.
    """

    def __init__(self, message, bit_loss):
        super().__init__(message)
        self._output_scale_factor = 2**bit_loss
        self._bit_loss = bit_loss

    @property
    def output_scale_factor(self):
        """``2 ** bit_loss``, computed once at construction."""
        return self._output_scale_factor

    @property
    def bit_loss(self):
        """The ``bit_loss`` value given at construction time."""
        return self._bit_loss


# Activation types for which PiecewiseActivator.get_piecewise_activator raises
# BackendNotImplementedError when the fit policy resolves to "enabled".
UNSUPPORTED_ACTIVATION_FIT = [
    ActivationType.leaky,
    ActivationType.elu,
    ActivationType.relu,
    ActivationType.linear,
    ActivationType.relu6,
    ActivationType.tanh,
    ActivationType.threshold,
    ActivationType.biased_delta,
    ActivationType.relu1,
]


class PiecewiseActivator:
    def __init__(
        self,
        y_points,
        x_points,
        inter_layer_precision_mode,
        saturations=None,
        should_add_delta_point=False,
        continuous=True,
    ):
        """
        Store the breakpoints of a piecewise-linear function and derive its slopes and offsets.

        Args:
            y_points: function values at the breakpoints
            x_points: breakpoint positions; zero-width intervals (repeated x) are allowed
            inter_layer_precision_mode: precision-mode object, queried via ``is_mode(16, 16)``
            saturations: ``None``, the string ``"implicit"``, or explicit edge offsets
            should_add_delta_point: when True an extra point/piece is inserted after scaling
            continuous: flag exposed through the ``continuous`` property

        """
        self._logger = default_logger()
        self.x_points, self.y_points = x_points, y_points
        self.inter_layer_precision_mode = inter_layer_precision_mode
        self.saturations = saturations
        self._continuous = continuous
        self._should_add_delta_point = should_add_delta_point
        # Gaps between consecutive breakpoints; a zero gap marks an intentional step.
        x_gaps = np.diff(self.x_points)
        self._x_points_diff = x_gaps
        self._identical_x_points = np.sum(x_gaps == 0)
        self._compute_slopes_and_offsets()

    @property
    def continuous(self):
        """Read-only view of the ``continuous`` flag passed to the constructor."""
        is_continuous = self._continuous
        return is_continuous

    def _remove_infinite_slopes(self, slopes, offsets):
        return slopes[self._x_points_diff != 0], offsets[self._x_points_diff != 0]

    def _remove_duplicate_x_points(self, x_points):
        return np.concatenate((np.array([x_points[0]]), np.array(x_points[1:])[self._x_points_diff != 0]))

    def _add_delta_point_and_piece(self, offset):
        self.x_points = np.insert(self.x_points, 2, self.x_points[1] + 1)
        self.slopes_e = np.insert(self.slopes_e, 2, 0)
        self.slopes_m = np.insert(self.slopes_m, 2, 0)
        self.offsets = np.insert(self.offsets, 2, offset)

    def _compute_slopes_and_offsets(self):
        """
        Recompute ``self.slopes`` and ``self.offsets`` from the current breakpoints.

        One slope/offset pair is produced per non-zero-width interval, and one extra pair is
        added at each end: flat (slope 0) when saturations are configured, otherwise a copy of
        the nearest piece.

        Raises:
            BackendQuantizationException: when the number of coinciding x points differs from
                the number recorded in ``self._identical_x_points``.

        """
        duplicate_count = np.sum(np.diff(self.x_points) == 0)
        if duplicate_count != self._identical_x_points:
            raise BackendQuantizationException(
                "two unexpected identical x points, cannot compute slopes. Probably layer inputs are too large",
            )

        if self.inter_layer_precision_mode.is_mode(16, 16):
            # In the (16, 16) mode the breakpoints are scaled by 2**8 before the division.
            x_points = [x * (2**8) for x in self.x_points]
        else:
            x_points = self.x_points

        slopes = old_div(np.diff(self.y_points), np.diff(x_points))
        offsets = self.y_points[:-1] - slopes * x_points[:-1]
        if self._identical_x_points:
            slopes, offsets = self._remove_infinite_slopes(slopes, offsets)

        if not self.saturations:
            # No saturation requested: the edge pieces extend the first and last real pieces.
            edge_slopes = [slopes[0], slopes[-1]]
            edge_offsets = [offsets[0], offsets[-1]]
        else:
            edge_slopes = [0, 0]
            if self.saturations == "implicit":
                edge_offsets = [self.y_points[0], self.y_points[-1]]
            else:
                edge_offsets = self.saturations

        self.slopes = np.concatenate((edge_slopes[:1], slopes, edge_slopes[-1:]))
        self.offsets = np.concatenate((edge_offsets[:1], offsets, edge_offsets[-1:]))

    def _set_scales(
        self,
        accumulator_scale,
        qp_out,
        beta,
        mantissa_bits,
        exp_bits,
        zp_compensation,
        shift,
        is_apu_2s_complement,
        signed_output=False,
        pre_act_limvals=None,
    ):
        """
        Rescale the breakpoints into quantized units and quantize the slopes and offsets.

        Args:
            accumulator_scale: divisor applied to the x points (and to ``pre_act_limvals``)
            qp_out: output quantization parameters; ``scale`` and ``zero_point`` are read
            beta: the y points and saturations are multiplied by ``2**beta``
            mantissa_bits: forwarded to ``_reduce_slope_precision``
            exp_bits: forwarded to ``_reduce_slope_precision``
            zp_compensation: added to the x points and folded into the offsets
            shift: right shift applied to the x points; forwarded to ``_reduce_slope_precision``
            is_apu_2s_complement: when truthy the lower range bounds are extended by one
            signed_output: when True the output zero point is treated as 0
            pre_act_limvals: optional pre-activation limits (any sequence or array); ``None``
                is forwarded as ``None`` so the check in ``_reduce_slope_precision`` applies

        Raises:
            BackendOffsetsException: when a rounded offset is outside the shifter-bias range.

        """
        self.x_points = list(old_div(self.x_points, accumulator_scale))
        self.y_points = [old_div(y * 2**beta, qp_out.scale) for y in self.y_points]
        self.saturations = (
            self.saturations
            if self.saturations in (None, "implicit")
            else [old_div(y * 2**beta, qp_out.scale) for y in self.saturations]
        )
        self._compute_slopes_and_offsets()
        zp_out = qp_out.zero_point if (not signed_output) else 0
        self.offsets = self.offsets - self.slopes * zp_compensation + (2.0**beta) * zp_out
        self.offsets = (np.round(self.offsets)).astype(int)

        # Dividing None (the default) or a plain list by the scale raises TypeError, so only
        # quantize the limits when they are given, and coerce them to an array first.
        if pre_act_limvals is None:
            quant_pre_act_limvals = None
        else:
            quant_pre_act_limvals = np.round(np.asarray(pre_act_limvals) / accumulator_scale)
        self._reduce_slope_precision(mantissa_bits, exp_bits, shift, signed_output, quant_pre_act_limvals, beta)

        self.x_points = [int(np.round(x + zp_compensation)) >> shift for x in self.x_points]
        max_value = self.inter_layer_precision_mode.shifter_bias_max_value
        min_value = -max_value - int(is_apu_2s_complement)
        a_max = self.inter_layer_precision_mode.activation_computation_max_value
        a_min = -a_max - int(is_apu_2s_complement)

        if np.any((self.offsets < min_value) | (self.offsets > max_value)):
            errmsg = f"One of the offsets value is not in range {[min_value, max_value]} {self.offsets} before clip"
            bit_loss = np.ceil(np.log2(np.max(np.abs(self.offsets)) / np.abs(max_value)))
            raise BackendOffsetsException(errmsg, bit_loss)

        # clip to the activation-computation range after computing slopes
        self.x_points = [int(num) for num in np.clip(self.x_points, a_min=a_min, a_max=a_max)]
        self.x_points = self._remove_duplicate_x_points(self.x_points)
        if self._should_add_delta_point:
            self._add_delta_point_and_piece((2.0**beta) * zp_out)

    def _reduce_slope_precision(self, mantissa_bits, exp_bits, shift, signed_output, pre_act_limvals, beta):
        """
        Quantize ``self.slopes`` into a mantissa/exponent pair (``self.slopes_m``, ``self.slopes_e``).

        Args:
            mantissa_bits: the frexp mantissa is scaled by ``2**mantissa_bits`` and rounded
            exp_bits: exponents at or above ``2**exp_bits`` are trimmed to ``2**exp_bits - 1``
            shift: subtracted from every exponent
            signed_output: only used to decide whether ``pre_act_limvals`` is mandatory
            pre_act_limvals: only checked for ``None``; not otherwise used in this method
            beta: not used in this method

        Raises:
            BackendNotImplementedError: ``signed_output`` is set but ``pre_act_limvals`` is None.
            BackendNegativeSlopesException: a slope with a non-zero mantissa has a negative exponent.

        """
        if signed_output and pre_act_limvals is None:
            raise BackendNotImplementedError(
                "Pre activations limvals are required when 'activations as weights' is enabled",
            )
        # np.frexp splits each slope into m and e such that slope == m * 2**e.
        m, e = np.frexp(self.slopes)
        self.slopes_m = np.round(m * 2**mantissa_bits)
        self.slopes_e = mantissa_bits - e - self.inter_layer_precision_mode.ebias

        # fix edge case where mantissa is 2**mantissa_bits.
        # In that case the mantissa is halved and the exponent is decremented by one.
        ones = np.ones_like(self.slopes_m, dtype=int)
        self.slopes_e = self.slopes_e - np.where(self.slopes_m == (2**mantissa_bits), ones, ones * 0)
        self.slopes_m = old_div(self.slopes_m, np.where(self.slopes_m == (2**mantissa_bits), ones * 2, ones))

        self.slopes_e = self.slopes_e - shift
        dummy = 1
        # Slopes that are exactly zero get the placeholder exponent `dummy`.
        # NOTE(review): the original comment here claimed a zero slope yields a mantissa of
        # 2**mantissa_bits; the code only keys on `self.slopes != 0` - confirm the intent.
        self.slopes_e = np.where(self.slopes != 0, self.slopes_e, dummy)
        # Only exponents that belong to a non-zero mantissa must be non-negative.
        non_zero_mantissa_slopes = self.slopes_e[np.nonzero(self.slopes_m)]
        if not np.all(non_zero_mantissa_slopes >= 0):
            errmsg = (
                f"Got negative value for piecewise calculation, slope exponent: {self.slopes_e}, shift: {shift}. "
                f"The range diff between the accumulator and the output cannot be represented."
            )
            raise BackendNegativeSlopesException(errmsg, np.abs(np.min(non_zero_mantissa_slopes)))
        # Remaining non-positive exponents (zero-mantissa entries) are clamped to 0.
        # Note that this turns self.slopes_e into a plain Python list.
        self.slopes_e = [slopes_e if (slopes_e > 0) else 0 for slopes_e in self.slopes_e]

        # fix edge cases where exponent is too big for HW
        is_exp_too_big = self.slopes_e >= ones * (2**exp_bits)
        if any(is_exp_too_big):
            self._logger.debug(
                "Activation slope exponent overflow has occurred, trimming to the max value and fixing mantissa...",
            )
            # Divide the mantissa by 2**(exponent excess) so mantissa / 2**exponent keeps its value.
            mantissa_divider = (ones * 2).astype(float) ** (self.slopes_e - ones * (2**exp_bits - 1))
            # TODO: don't round the mantissa twice
            mantissa_not_zero_before = np.not_equal(self.slopes_m, 0)
            self.slopes_m = np.round(old_div(self.slopes_m, np.where(is_exp_too_big, mantissa_divider, ones)))
            if np.any(np.logical_and(mantissa_not_zero_before, self.slopes_m == 0)):
                self._logger.debug("Activation slope mantissa became zero due to exponent overflow fix")
            # After this assignment self.slopes_e is an ndarray again.
            self.slopes_e = np.where(is_exp_too_big, ones * (2**exp_bits - 1), self.slopes_e)

    def _get_piecewise_params(self):
        """Return ``(x_points, slopes, offsets, slopes_m, slopes_e)`` as currently stored."""
        params = (
            self.x_points,
            self.slopes,
            self.offsets,
            self.slopes_m,
            self.slopes_e,
        )
        return params

    def _remove_unused_end_pieces(self, limvals_out):
        if limvals_out is None:
            return
        unused_edges = 0
        for low, high in zip(self.y_points[-2::-1], self.y_points[::-1]):
            if (low >= limvals_out[1]) and (high >= limvals_out[1]):
                unused_edges += 1
            else:
                break
        self._remove_n_pieces_from_end(unused_edges)

    def _remove_n_pieces_from_end(self, count):
        if (len(self.y_points) - count) < 2:
            count = len(self.y_points) - 2
        if count <= 0:
            return
        self.x_points = self.x_points[:-count]
        self.y_points = self.y_points[:-count]
        self._x_points_diff = np.diff(self.x_points)
        self.offsets = self.offsets[:-count]
        self.slopes = self.slopes[:-count]

    def get_num_points(self):
        num_points = len(self._remove_duplicate_x_points(self.x_points))
        if self._should_add_delta_point:
            num_points += 1
        return num_points

    @classmethod
    def _inv_pos_minus_reciprocal_perm(cls, x_points, y_points):
        """Return ``(x, y)`` lists in reversed order, with every x value negated."""
        negated_x = -np.array(x_points)
        mirrored_x = list(negated_x)[::-1]
        mirrored_y = list(np.array(y_points))[::-1]
        return mirrored_x, mirrored_y

    @classmethod
    def get_default_fit_by_activation(cls, activation_fit, activation_type):
        """
        Resolve the ``allowed`` fit policy into a concrete one.

        Any policy other than ``ActivationFitPolicy.allowed`` is returned untouched. For
        ``allowed`` the result is ``enabled`` for exp, sqrt and log and ``disabled`` for every
        other activation type.

        Args:
            activation_fit: the policy enum value coming from the script
            activation_type: the activation type

        Returns: the resolved ActivationFitPolicy value

        """
        if not (activation_fit == ActivationFitPolicy.allowed):
            return activation_fit

        fitted_by_default = {ActivationType.exp, ActivationType.sqrt, ActivationType.log}
        if activation_type in fitted_by_default:
            return ActivationFitPolicy.enabled
        return ActivationFitPolicy.disabled

    @classmethod
    def get_piecewise_activator(
        cls,
        activation_type,
        inter_layer_precision_mode,
        leaky_alpha=None,
        out_scale=None,
        activation_threshold=None,
        activation_delta_bias=None,
        pre_activation_limvals=None,
        activation_fit=None,
        limvals_out=None,
        activation_less_value=None,
        hardsigmoid_alpha=None,
        hardsigmoid_beta=None,
        clip_min=None,
        clip_max=None,
        activation_greater_value=None,
    ):
        """
        Build the un-scaled activator for ``activation_type``.

        The fit policy is first resolved with ``get_default_fit_by_activation``. When it is
        ``enabled`` the table is produced by ``piecewise_fitting``; otherwise a fixed table is
        taken from ``piecewise_fixed_by_activation`` and, for continuous activators, trailing
        pieces above ``limvals_out[1]`` are dropped.

        Args:
            activation_type: the activation type
            inter_layer_precision_mode: precision-mode object handed to the activator
            pre_activation_limvals: pre-activation limits; may be None, in which case the
                negative-input check for sqrt/log is skipped and the fitting path only
                produces a placeholder table with the right number of points
            activation_fit: requested fit policy
            limvals_out: output limits used to trim unused end pieces
            (the remaining arguments are forwarded to ``piecewise_fixed_by_activation``)

        Raises:
            BackendNotImplementedError: fitting is enabled for a type in UNSUPPORTED_ACTIVATION_FIT.
            BackendQuantizationException: sqrt/log is requested with a negative lower limit.

        """
        activation_fit = cls.get_default_fit_by_activation(activation_fit, activation_type)
        fit_enabled = activation_fit == ActivationFitPolicy.enabled
        if fit_enabled and activation_type in UNSUPPORTED_ACTIVATION_FIT:
            raise BackendNotImplementedError(
                f"the activation {activation_type.name} passed to piecewise_calculator cant be fitted",
            )
        # The limits are optional (piecewise_fitting handles None), so only validate the
        # input range when they are actually available instead of failing on None[0].
        if (
            (activation_type in [ActivationType.sqrt, ActivationType.log])
            and pre_activation_limvals is not None
            and (pre_activation_limvals[0] < 0)
        ):
            raise BackendQuantizationException(f"{activation_type.value} activation receives negative input")

        if fit_enabled:
            # any activation supported by get_function_by_activation
            return cls.piecewise_fitting(activation_type, inter_layer_precision_mode, pre_activation_limvals)

        # for all the activation types
        act = cls.piecewise_fixed_by_activation(
            activation_type,
            inter_layer_precision_mode,
            leaky_alpha,
            out_scale,
            activation_threshold,
            activation_delta_bias,
            activation_less_value,
            hardsigmoid_alpha,
            hardsigmoid_beta,
            clip_min,
            clip_max,
            activation_greater_value,
        )
        if act.continuous:
            act._remove_unused_end_pieces(limvals_out)

        return act

    @classmethod
    def elu_piecewise_activator(cls, inter_layer_precision_mode):
        """Return the fixed breakpoint table used for the ELU activation."""
        # (x, y) breakpoints
        knots = [
            (-4.60517019, -1.0),
            (-2.30258509, -0.91),
            (-1.2039728, -0.71),
            (-0.51082562, -0.41),
            (0.0, 0.0),
            (10.0, 10.0),
        ]
        x_points, y_points = (list(axis) for axis in zip(*knots))
        return cls(y_points=y_points, x_points=x_points, inter_layer_precision_mode=inter_layer_precision_mode)

    @classmethod
    def relu_piecewise_activator(cls, inter_layer_precision_mode):
        """Return the ReLU table: zero up to x = 0, identity afterwards."""
        return cls(
            y_points=[0.0, 0.0, 10.0],
            x_points=[-10.0, 0.0, 10.0],
            inter_layer_precision_mode=inter_layer_precision_mode,
        )

    @classmethod
    def linear_piecewise_activator(cls, inter_layer_precision_mode):
        """Return the single-piece identity table through (0, 0) and (10, 10)."""
        return cls(
            y_points=[0.0, 10.0],
            x_points=[0.0, 10.0],
            inter_layer_precision_mode=inter_layer_precision_mode,
        )

    @classmethod
    def sigmoid_piecewise_activator(cls, inter_layer_precision_mode):
        """Return the fixed 8-point breakpoint table used for the sigmoid activation."""
        # (x, y) breakpoints
        knots = [
            (-5.53733427, -0.0020313725490196085),
            (-3.09927295, 0.03765490196078431),
            (-1.90616982, 0.12496470588235295),
            (-0.99164017, 0.26783529411764706),
            (0.99164017, 0.732164705882353),
            (1.90616982, 0.875035294117647),
            (3.09927295, 0.9623450980392158),
            (5.53733427, 1.0020313725490195),
        ]
        x_points, y_points = (list(axis) for axis in zip(*knots))
        return cls(y_points=y_points, x_points=x_points, inter_layer_precision_mode=inter_layer_precision_mode)

    @classmethod
    def piecewise_fixed_by_activation(
        cls,
        activation_type,
        inter_layer_precision_mode,
        leaky_alpha=None,
        out_scale=None,
        activation_threshold=None,
        activation_delta_bias=None,
        activation_less_value=None,
        hardsigmoid_alpha=None,
        hardsigmoid_beta=None,
        clip_min=None,
        clip_max=None,
        activation_greater_value=None,
    ):
        """
        Dispatch to the fixed-table builder that matches ``activation_type``.

        Args:
            activation_type: the activation type to build a table for
            inter_layer_precision_mode: forwarded to the selected builder
            leaky_alpha: used by leaky
            out_scale: used by relu6 / relu1
            activation_threshold: used by threshold
            activation_delta_bias: used by biased_delta
            activation_less_value: used by less
            hardsigmoid_alpha, hardsigmoid_beta: used by hardsigmoid
            clip_min, clip_max: used by clip
            activation_greater_value: accepted but not used by this method

        Returns:
            the activator built by the selected builder

        Raises:
            BackendNotImplementedError: for ``greater``, for any other ActivationType member
                without a builder, and for objects that are not ActivationType members.

        """
        if activation_type == ActivationType.leaky:
            return cls.leaky_alpha_piecewise_activator(leaky_alpha, inter_layer_precision_mode)
        if activation_type == ActivationType.elu:
            return cls.elu_piecewise_activator(inter_layer_precision_mode)
        if activation_type == ActivationType.relu:
            return cls.relu_piecewise_activator(inter_layer_precision_mode)
        if activation_type == ActivationType.linear:
            return cls.linear_piecewise_activator(inter_layer_precision_mode)
        if activation_type == ActivationType.sigmoid:
            return cls.sigmoid_piecewise_activator(inter_layer_precision_mode)
        if activation_type == ActivationType.exp:
            return cls.exp_piecewise_activator(inter_layer_precision_mode)
        if activation_type == ActivationType.inv_pos:
            return cls.inverse_pos_piecewise_activator(inter_layer_precision_mode)
        if activation_type == ActivationType.softplus:
            return cls.softplus_piecewise_activator(inter_layer_precision_mode)
        if activation_type == ActivationType.relu6:
            return cls.relu_n_piecewise_activator(6.0, out_scale, inter_layer_precision_mode)
        if activation_type == ActivationType.relu1:
            return cls.relu_n_piecewise_activator(1.0, out_scale, inter_layer_precision_mode)
        if activation_type == ActivationType.tanh:
            return cls.tanh_piecewise_activator(inter_layer_precision_mode)
        if activation_type == ActivationType.threshold:
            return cls.threshold_piecewise_activator(activation_threshold, inter_layer_precision_mode)
        if activation_type == ActivationType.biased_delta:
            return cls.biased_delta_piecewise_activator(activation_delta_bias, inter_layer_precision_mode)
        if activation_type == ActivationType.less:
            return cls.less_piecewise_activator(activation_less_value, inter_layer_precision_mode)
        if activation_type == ActivationType.greater:
            raise BackendNotImplementedError(
                f"activation {activation_type.name} is not implemented in legacy quantization",
            )
        if activation_type == ActivationType.silu:
            return cls.silu_piecewise_activator(inter_layer_precision_mode)
        if activation_type == ActivationType.mish:
            return cls.mish_piecewise_activator(inter_layer_precision_mode)
        if activation_type == ActivationType.hardsigmoid:
            return cls.hardsigmoid_piecewise_activator(hardsigmoid_alpha, hardsigmoid_beta, inter_layer_precision_mode)
        if activation_type == ActivationType.gelu:
            return cls.gelu_piecewise_activator(inter_layer_precision_mode)
        if activation_type == ActivationType.hardswish:
            return cls.hardswish_piecewise_activator(inter_layer_precision_mode)
        if activation_type == ActivationType.clip:
            return cls.clip_piecewise_activator(inter_layer_precision_mode, clip_min, clip_max)

        # A member of ActivationType that has no builder above.
        # isinstance is used instead of `in ActivationType`: the containment check raises
        # TypeError for non-Enum objects before Python 3.12 and accepts raw values (which have
        # no `.name`) from 3.12 on, so the final error below could never be reported reliably.
        if isinstance(activation_type, ActivationType):
            raise BackendNotImplementedError(
                "activation %s is not implemented in quantizer, not supported yet on hardware "
                % (activation_type.name),
            )

        # Anything else is not an activation type at all.
        raise BackendNotImplementedError(
            "activation (%s) passed to piecewise_calculator is from invalid type" % (str(activation_type)),
        )

    @classmethod
    def piecewise_fitting(
        cls,
        activation_type,
        inter_layer_precision_mode,
        pre_activation_limvals=None,
        n_segments=8,
        num_of_samples=350,
    ):
        """
        Build an activator whose breakpoints are fitted to the activation function.

        When ``pre_activation_limvals`` is given, the breakpoints come from
        ``find_optimized_breaks``. When it is None, only the number of points matters to the
        caller, so a placeholder table with ``n_segments + 1`` points is produced and the
        (slow) optimization is skipped.

        Note for inv_pos activation:
        In order to quantize 1/x, which has negative derivatives, we substitute it with -1/x and
        initialize the kernel with -1's instead of 1's. This means we reverse the x_points and
        y_points and multiply x_points by -1
        Args:
           activation_type (hailo_sdk_common.hailo_nn.hn_definitions.ActivationType): the activation type
           inter_layer_precision_mode: precision-mode object handed to the activator
           pre_activation_limvals (list): the min and max of the pre activation, or None
           n_segments (int, optional): desired number of line segments
           num_of_samples (int, optional): number of samples for the optimization
        Returns:
           an instance of this class built from the resulting x_points and y_points.
        """
        logger = default_logger()
        if pre_activation_limvals is None:
            # placeholder points: only their count is meaningful
            x_points = np.arange(n_segments + 1)
            y_points = np.arange(n_segments + 1) + 10
        else:
            # the real breakpoints are required
            logger.debug(
                f"the func is {activation_type.value} the number of samples are: {num_of_samples} the number of pieces are {n_segments}",
            )
            target_func = get_function_by_activation(activation_type)
            x_points, y_points = cls.find_optimized_breaks(
                logger,
                target_func,
                n_segments,
                num_of_samples,
                pre_activation_limvals,
            )
            if activation_type == ActivationType.inv_pos:
                x_points, y_points = cls._inv_pos_minus_reciprocal_perm(x_points, y_points)
        return cls(y_points=y_points, x_points=x_points, inter_layer_precision_mode=inter_layer_precision_mode)

    @classmethod
    def exp_piecewise_activator(cls, inter_layer_precision_mode):
        """Return the fixed 7-point breakpoint table used for the exp activation."""
        # (x, y) breakpoints
        knots = [
            (-1.6, 0.19583962),
            (-0.9, 0.39437257),
            (-0.3, 0.71859367),
            (0.1, 1.07201579),
            (0.55, 1.68125543),
            (0.95, 2.50813837),
            (1.4, 3.93354397),
        ]
        x_points, y_points = (list(axis) for axis in zip(*knots))
        return cls(y_points=y_points, x_points=x_points, inter_layer_precision_mode=inter_layer_precision_mode)

    @classmethod
    def inverse_pos_piecewise_activator(cls, inter_layer_precision_mode):
        """
        Return the fixed table for the inverse (1/x) activation on positive inputs.

        in order to quantize 1/x, which has negative derivatives,
        we substitute it with -1/x and initialize the kernel with
        -1's instead of 1's. This means we reverse the x_points
        and y_points and multiply x_points by -1
        """
        # (x, y) breakpoints before mirroring
        knots = [
            (0.53052062, 1.87286626),
            (0.66966087, 1.4767212),
            (0.87466728, 1.12846775),
            (1.16936897, 0.84227976),
            (1.59553682, 0.61600893),
            (2.22181007, 0.44147993),
            (3.15596193, 0.30980502),
            (4.62858967, 0.2103742),
            (6.99250555, 0.13871488),
        ]
        raw_x, raw_y = (list(axis) for axis in zip(*knots))
        x_points, y_points = cls._inv_pos_minus_reciprocal_perm(raw_x, raw_y)
        return cls(y_points=y_points, x_points=x_points, inter_layer_precision_mode=inter_layer_precision_mode)

    @classmethod
    def softplus_piecewise_activator(cls, inter_layer_precision_mode):
        """Return the fixed 10-point breakpoint table used for the softplus activation."""
        # (x, y) breakpoints
        knots = [
            (-10.0, 0.0),
            (-4.0, 0.01017738),
            (-2.33749224, 0.08137196),
            (-1.26575537, 0.23581468),
            (-0.40175493, 0.49880582),
            (0.40645237, 0.90339035),
            (1.27026613, 1.50514079),
            (2.34072059, 2.42187092),
            (4.0, 4.01021905),
            (10, 10.01021905),
        ]
        x_points, y_points = (list(axis) for axis in zip(*knots))
        return cls(y_points=y_points, x_points=x_points, inter_layer_precision_mode=inter_layer_precision_mode)

    @classmethod
    def silu_piecewise_activator(cls, inter_layer_precision_mode):
        """
        Return the fixed table used for the SiLU activation.

        SiLU has negative slopes which is not supported in Hailo8. The region where the function
        decreases is therefore approximated by constant steps (slope 0, expressed as repeated
        x values), and the increasing region by ordinary linear pieces.
        The zero-width intervals are dropped in '_remove_infinite_slopes', which leaves 9 pieces.
        """
        # (x, y) breakpoints; a repeated x marks a vertical step
        knots = [
            (-6.0, 0.0),
            (-5.0, 0.0),
            (-5.0, -0.066),
            (-3.388, -0.066),
            (-3.388, -0.156),
            (-2.339, -0.156),
            (-2.339, -0.263),
            (-1.0, -0.263),
            (-0.282, -0.14),
            (0.384, 0.21),
            (1.127, 0.835),
            (5.0, 4.993),
            (6.0, 6.0),
        ]
        x_points, y_points = (list(axis) for axis in zip(*knots))
        return cls(y_points=y_points, x_points=x_points, inter_layer_precision_mode=inter_layer_precision_mode)

    @classmethod
    def mish_piecewise_activator(cls, inter_layer_precision_mode):
        """
        Return the fixed table used for the Mish activation.

        Mish has negative slopes which is not supported in Hailo8. The region where the function
        decreases is therefore approximated by constant steps (slope 0, expressed as repeated
        x values), and the increasing region by ordinary linear pieces.
        The zero-width intervals are dropped in '_remove_infinite_slopes', which leaves 9 pieces.
        """
        # (x, y) breakpoints; a repeated x marks a vertical step
        knots = [
            (-6.0, 0.0),
            (-5.0, 0.0),
            (-5.0, -0.068),
            (-3.314, -0.068),
            (-3.314, -0.169),
            (-2.248, -0.169),
            (-2.248, -0.281),
            (-1.0, -0.281),
            (-0.21, -0.141),
            (0.541, 0.385),
            (2.72, 2.706),
            (5.0, 5.003),
            (6.0, 6.0),
        ]
        x_points, y_points = (list(axis) for axis in zip(*knots))
        return cls(y_points=y_points, x_points=x_points, inter_layer_precision_mode=inter_layer_precision_mode)

    @classmethod
    def hardsigmoid_piecewise_activator(cls, alpha, beta, inter_layer_precision_mode):
        """
        Return the table for hard-sigmoid: 0 up to ``-beta / alpha``, rising linearly to 1 at
        ``(1 - beta) / alpha``, and 1 afterwards.

        Args:
            alpha: slope of the linear section; must be given and non-zero
            beta: bias of the linear section; must be given
            inter_layer_precision_mode: precision-mode object handed to the activator

        Raises:
            BackendQuantizationException: alpha or beta is missing, or alpha is zero.

        """
        # Fail with a clear message (like the leaky / threshold builders) instead of an
        # opaque TypeError or a division by zero.
        if alpha is None or beta is None:
            raise BackendQuantizationException("You must provide alpha and beta to use hardsigmoid activation")
        if alpha == 0:
            raise BackendQuantizationException("hardsigmoid alpha must be non-zero")

        x_points = [-10.0, -(beta / alpha), (1 - beta) / alpha, 10.0]
        y_points = [0.0, 0.0, 1.0, 1.0]

        return cls(y_points=y_points, x_points=x_points, inter_layer_precision_mode=inter_layer_precision_mode)

    @classmethod
    def gelu_piecewise_activator(cls, inter_layer_precision_mode):
        """
        Return the fixed table used for the GELU activation.

        Repeated x values encode vertical steps between constant sections; the remaining
        breakpoints describe ordinary linear pieces.
        """
        # (x, y) breakpoints; a repeated x marks a vertical step
        knots = [
            (-6.0, 0.0),
            (-5.0, 0.0),
            (-5.0, -0.005706000151),
            (-2.0751226149963, -0.005706000151),
            (-2.0751226149963, -0.0731837008054),
            (-1.4004078406866, -0.0731837008054),
            (-1.4004078406866, -0.1530002471987),
            (-0.5, -0.1530002471987),
            (0.0659232450407, 0.014024374947),
            (0.6541225171453, 0.4672866620057),
            (2.4055980282936, 2.3966123906533),
            (5.0, 5.0034001838679),
            (6.0, 6.0),
        ]
        x_points, y_points = (list(axis) for axis in zip(*knots))
        return cls(y_points=y_points, x_points=x_points, inter_layer_precision_mode=inter_layer_precision_mode)

    @classmethod
    def hardswish_piecewise_activator(cls, inter_layer_precision_mode):
        """
        Return the fixed table used for the hard-swish activation.

        Repeated x values encode vertical steps between constant sections; the remaining
        breakpoints describe ordinary linear pieces.
        """
        # (x, y) breakpoints; a repeated x marks a vertical step
        knots = [
            (-6.0, 0.0),
            (-5.0, 0.0),
            (-5.0, -0.00475815025),
            (-2.78984109, -0.00475815025),
            (-2.78984109, -0.190058066),
            (-2.29737651, -0.190058066),
            (-2.29737651, -0.347909906),
            (-1.0, -0.347909906),
            (0.35217667, 0.145858497),
            (1.70560031, 1.28680126),
            (3.13060766, 3.13060774),
            (5.0, 4.99999996),
            (6.0, 6.0),
        ]
        x_points, y_points = (list(axis) for axis in zip(*knots))
        return cls(y_points=y_points, x_points=x_points, inter_layer_precision_mode=inter_layer_precision_mode)

    @classmethod
    def clip_piecewise_activator(cls, inter_layer_precision_mode, clip_min, clip_max):
        """Return the clip table: flat at ``clip_min``, identity in between, flat at ``clip_max``."""
        below_min = clip_min - 1
        above_max = clip_max + 1
        return cls(
            y_points=[clip_min, clip_min, clip_max, clip_max],
            x_points=[below_min, clip_min, clip_max, above_max],
            inter_layer_precision_mode=inter_layer_precision_mode,
        )

    @classmethod
    def find_optimized_breaks(cls, logger, func, n_segments, num_of_samples, pre_activation_limvals):
        """
        Fit ``func`` with a piecewise-linear model over the pre-activation range.

        ``func`` is sampled at ``num_of_samples`` evenly spaced points between the first and
        last entries of ``pre_activation_limvals`` and a ``pwlf.PiecewiseLinFit`` with
        ``n_segments`` segments (seed 1) is fitted to the samples.

        Returns:
            (x_points, y_points): the value returned by the fit, and the model prediction at
            those points.
        """
        logger.info("activation fitting started")
        lower, upper = pre_activation_limvals[0], pre_activation_limvals[-1]
        samples_x = np.linspace(lower, upper, num_of_samples)
        samples_y = func(samples_x)

        fit_started_at = time.time()
        model = pwlf.PiecewiseLinFit(samples_x, samples_y)
        x_points = model.fit(n_segments=n_segments, seed=1)
        elapsed_time = time.time() - fit_started_at
        logger.debug(f"activation fitting ended in {elapsed_time:.2f} seconds")

        y_points = model.predict(x_points)
        # exact function values at the breakpoints, logged only to show the fitting error
        exact_y = func(x_points)
        logger.debug(f"the limvals are ({lower},{upper}) .")
        logger.debug(f"x = {x_points}")
        logger.debug(f"y = {y_points}")
        logger.debug(f"ssr = {model.ssr}")
        logger.debug(f"error y_points predicted and y_points_real = {exact_y - y_points}")
        return x_points, y_points

    @classmethod
    def relu_n_piecewise_activator(cls, n, out_scale, inter_layer_precision_mode):
        """
        Return the table for a ReLU capped at ``n``.

        When ``out_scale`` is given, the cap is the smaller of ``n`` and the largest output
        value ``(2**output_activation_bits - 1) * out_scale``.
        """
        # TODO: replace by hw arch const
        if out_scale is not None:
            largest_code = 2**inter_layer_precision_mode.output_activation_bits - 1
            max_out = np.min([np.float32(largest_code * out_scale), n])
        else:
            max_out = n
        return cls(
            y_points=[0.0, 0.0, max_out, max_out],
            x_points=[-10.0, 0.0, max_out, 10.0],
            inter_layer_precision_mode=inter_layer_precision_mode,
        )

    @classmethod
    def tanh_piecewise_activator(cls, inter_layer_precision_mode):
        """Return the fixed 9-point breakpoint table used for the tanh activation."""
        # (x, y) breakpoints
        knots = [
            (-10.0, -1.0),
            (-2.6210944, -0.99744147),
            (-1.61987821, -0.9363582),
            (-1.02180768, -0.78464631),
            (-0.53728883, -0.50597846),
            (0.60000386, 0.56275497),
            (1.25945738, 0.87365114),
            (2.28178297, 0.99515878),
            (10.0, 1.0),
        ]
        x_points, y_points = (list(axis) for axis in zip(*knots))
        return cls(y_points=y_points, x_points=x_points, inter_layer_precision_mode=inter_layer_precision_mode)

    @classmethod
    def leaky_alpha_piecewise_activator(cls, leaky_alpha, inter_layer_precision_mode):
        """
        Return the leaky-ReLU table: slope ``leaky_alpha`` below zero, identity above.

        Raises:
            BackendQuantizationException: ``leaky_alpha`` is None.
        """
        if leaky_alpha is None:
            raise BackendQuantizationException("You must provide alpha to use Leaky Relu")
        negative_end = -10.0 * leaky_alpha
        return cls(
            y_points=[negative_end, 0.0, 10.0],
            x_points=[-10.0, 0.0, 10.0],
            inter_layer_precision_mode=inter_layer_precision_mode,
        )

    @classmethod
    def threshold_piecewise_activator(cls, threshold, inter_layer_precision_mode):
        """
        Return the threshold table: zero below ``threshold``, a step up to ``threshold`` at
        ``x == threshold``, then a line to (100, 100).

        Raises:
            BackendQuantizationException: ``threshold`` is None or larger than 100.
        """
        if threshold is None:
            raise BackendQuantizationException("You must provide threshold to use threshold activation")
        if threshold > 100.0:
            raise BackendQuantizationException("threshold > 100 not supported")
        return cls(
            y_points=[0.0, 0.0, threshold, 100.0],
            x_points=[-100.0, threshold, threshold, 100.0],
            inter_layer_precision_mode=inter_layer_precision_mode,
        )

    @classmethod
    def biased_delta_piecewise_activator(cls, activation_delta_bias, inter_layer_precision_mode):
        """
        Return the biased-delta table: a constant ``activation_delta_bias`` on three points.

        The activator is created as non-continuous and with the delta point enabled, so an
        extra point/piece is inserted when the scales are set.
        """
        constant_level = [activation_delta_bias] * 3
        return cls(
            y_points=constant_level,
            x_points=[-10.0, 0.0, 10.0],
            inter_layer_precision_mode=inter_layer_precision_mode,
            should_add_delta_point=True,
            continuous=False,
        )

    @classmethod
    def less_piecewise_activator(cls, val, inter_layer_precision_mode):
        """
        Return the table for the "less" activation: 1 up to ``val``, 0 after it.

        ``val`` may be a scalar or an array; an array must hold a single distinct value.

        Raises:
            BackendNotImplementedError: ``val`` is an array with more than one distinct value.
        """
        val = np.array(val)
        if val.ndim != 0:
            distinct_values = np.unique(val)
            if len(distinct_values) != 1:
                raise BackendNotImplementedError("Less activation with multiple values in layer is not supported")
            val = distinct_values[0]
        return cls(
            y_points=[1, 1, 0, 0],
            x_points=[val - 10, val, val, val + 10],
            inter_layer_precision_mode=inter_layer_precision_mode,
        )

    @classmethod
    def get_piecewise_activator_scaled_params(
        cls,
        activation_type,
        hw_arch,
        accumulator_scale,
        qp_out,
        beta,
        mantissa_bits,
        exp_bits,
        leaky_alpha,
        zp_compensation,
        shift,
        is_apu_2s_complement,
        inter_layer_precision_mode,
        activation_threshold,
        activation_delta_bias,
        pre_activation_limvals=None,
        activation_fit=None,
        quantization_groups=1,
        limvals_out=None,
        signed_output=False,
        activation_less_value=None,
        hardsigmoid_alpha=None,
        hardsigmoid_beta=None,
        clip_min=None,
        clip_max=None,
        activation_greater_value=None,
    ):
        """
        Build the activator for ``activation_type``, scale it, and return both views of it.

        Returns:
            a 9-tuple: the scaled ``x_points, slopes, offsets, slopes_m, slopes_e`` followed by
            the ``x_points, y_points, slopes, offsets`` captured before scaling.

        Raises:
            BackendQuantizationException: more than one quantization group is requested on a
                non-mercury architecture and either the group count or the piece count exceeds
                its maximum.

        """
        activator = cls.get_piecewise_activator(
            activation_type,
            inter_layer_precision_mode,
            leaky_alpha=leaky_alpha,
            out_scale=qp_out.scale,
            activation_threshold=activation_threshold,
            activation_delta_bias=activation_delta_bias,
            pre_activation_limvals=pre_activation_limvals,
            activation_fit=activation_fit,
            limvals_out=limvals_out,
            activation_less_value=activation_less_value,
            hardsigmoid_alpha=hardsigmoid_alpha,
            hardsigmoid_beta=hardsigmoid_beta,
            clip_min=clip_min,
            clip_max=clip_max,
            activation_greater_value=activation_greater_value,
        )

        # snapshot of the un-scaled state; _set_scales rebinds these attributes below
        x_points_before_scale = activator.x_points
        y_points_before_scale = activator.y_points
        slopes_before_scale = activator.slopes
        offsets_before_scale = activator.offsets

        grouped_on_limited_arch = quantization_groups > 1 and not hw_arch.is_mercury_arch
        if grouped_on_limited_arch and quantization_groups > MAX_ALLOWED_QUANTIZATION_GROUPS:
            raise BackendQuantizationException("Number of quantization groups exceeds maximum")
        if grouped_on_limited_arch and len(activator.x_points) - 1 > MAX_ALLOWED_PIECE_NUMBER:
            raise BackendQuantizationException("Number of pieces exceeds maximum")

        activator._set_scales(
            accumulator_scale,
            qp_out,
            beta,
            mantissa_bits,
            exp_bits,
            zp_compensation,
            shift,
            is_apu_2s_complement,
            signed_output=signed_output,
            pre_act_limvals=pre_activation_limvals,
        )
        scaled_params = activator._get_piecewise_params()

        return (
            *scaled_params,
            x_points_before_scale,
            y_points_before_scale,
            slopes_before_scale,
            offsets_before_scale,
        )


def get_function_by_activation(activation_type):
    """
    Return the numpy-based callable that evaluates ``activation_type``.

    Supported types are exp, sqrt, log, inv_pos, sigmoid and softplus.

    Raises:
        BackendQuantizationException: for any other activation type.
    """

    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    def inverse_activation(x):
        return 1.0 / x

    def softplus(x):
        # log(1 + exp(x)) written as log1p(exp(-|x|)) + max(x, 0)
        return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0)

    # checked in order with ==, first match wins
    candidates = (
        (ActivationType.exp, np.exp),
        (ActivationType.sqrt, np.sqrt),
        (ActivationType.log, np.log),
        (ActivationType.inv_pos, inverse_activation),
        (ActivationType.sigmoid, sigmoid),
        (ActivationType.softplus, softplus),
    )
    for supported_type, function in candidates:
        if activation_type == supported_type:
            return function
    raise BackendQuantizationException(f"we do not support fitting to the {activation_type} activation")
