import time
from copy import copy, deepcopy

import numpy as np
from past.utils import old_div
from scipy.stats import gmean

from hailo_model_optimization.acceleras.utils.acceleras_definitions import EqualizationMode
from hailo_sdk_client.quantization.tools.optimize_kernel_ranges import mmse
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, LayerType
from hailo_sdk_common.logger.logger import default_logger


class FactorsProducer:
    """Container for the per-index factor vectors derived from the producer layers.

    The attributes are assigned in a fixed order; ``vars(instance)`` (iterated e.g. by
    ``FactorsCalculator._print_debug_information``) yields them in that order.
    """

    def __init__(
        self,
        factors_min,
        factors_max,
        kernel_factor_min,
        activations_factor_min,
        post_activations_factor_min,
        pre_activations_factor_min,
        kernel_factor_max,
        activations_factor_max,
    ):
        # Overall lower / upper bounds.
        self.factors_min, self.factors_max = factors_min, factors_max
        # Component lower bounds; pre-activation is stored before post-activation on purpose (vars() order).
        self.kernel_factor_min = kernel_factor_min
        self.activations_factor_min = activations_factor_min
        self.pre_activations_factor_min, self.post_activations_factor_min = (
            pre_activations_factor_min,
            post_activations_factor_min,
        )
        # Component upper bounds.
        self.kernel_factor_max, self.activations_factor_max = kernel_factor_max, activations_factor_max


class FactorsConsumers:
    """Per-index lower and upper kernel factor bounds derived from the consumer layers."""

    def __init__(self, kernel_factor_min, kernel_factor_max):
        # Assigned min-first so ``vars(instance)`` yields min before max.
        self.kernel_factor_min, self.kernel_factor_max = kernel_factor_min, kernel_factor_max


class FactorsCalculatorError(Exception):
    """Raised when equalization factors cannot be computed (unknown mode, non-positive factors, ...)."""


class FactorsCalculator:
    EPSILON = 1e-30
    RELU6_MINIMUM_SCALE = 0.7  # This number is chosen randomly and works

    def __init__(self, equiv_set, stats_equiv_layers, max_concat, configuration):
        self._equiv_set = equiv_set
        self._stats_equiv_layers = stats_equiv_layers
        self._max_concat = max_concat
        self._configuration = configuration
        self._set_indices_relu6 = equiv_set.set_indices_relu6
        self.RELU_SENSITIVE = True

    @property
    def max_activation_factor(self):
        return self._configuration.max_activation_factor

    @property
    def consumers(self):
        return self._equiv_set.consumers

    @property
    def producers(self):
        return self._equiv_set.producers

    @property
    def number_of_factors(self):
        return len(self._equiv_set.unique_indices)

    @property
    def mode(self):
        return self._configuration.mode

    @staticmethod
    def timer(start, end):
        """Log (debug level) the time elapsed between ``start`` and ``end`` (seconds) as HH:MM:SS.ss."""
        hours, remainder = divmod(end - start, 3600)
        minutes, seconds = divmod(remainder, 60)
        elapsed_text = f"{int(hours):0>2}:{int(minutes):0>2}:{seconds:05.2f}"
        default_logger().debug(f"Ending time for algo {elapsed_text} ")

    @classmethod
    def algorithm_information(cls, start, end, delta, iteration, sol):
        """Log the final iteration summary (noise delta and solution value), then the elapsed time."""
        summary = f"In iteration {iteration} diff: {delta} sol: {sol:0.7f}"
        default_logger().debug(summary)
        cls.timer(start, end)

    @property
    def two_stage(self):
        return self._configuration.two_stage

    def get_factors(self):
        """
        Compute the equalization factors according to the configured mode.

        Returns:
            ``(factors_solution, sol_value, initial_sol_value, no_algo_sol)``: the chosen factor vector, its noise
            value, the noise of the min-based solution, and the noise without equalization (all-ones factors).

        Raises:
            FactorsCalculatorError: if the configured mode is unknown.
        """
        mode = self.mode
        if mode == EqualizationMode.kernel_equalization:
            result = self._get_kernel_equalization_factors()
        elif mode == EqualizationMode.min_based:
            result = self._get_min_based_factors()
        elif mode == EqualizationMode.noise_based:
            result = self._get_noise_based_factors()
            default_logger().debug("############ statistics after algo #############")
            self._get_algorithm_statistics(result[0])
        else:
            raise FactorsCalculatorError(f"The algo type for equalization got unknown value {mode} .")

        factors_solution, sol_value, initial_sol_value, no_algo_sol = result
        return factors_solution, sol_value, initial_sol_value, no_algo_sol

    # algorithms
    def _get_kernel_equalization_factors(self):
        """
        Balance the factor vectors "desired" by every kernel using a skewed geometric mean.

        Each producer/consumer kernel suggests per-channel log2 shifts for its own quantization (roughly the
        layerwise/channelwise range ratios). The shifts are summed across both sides and divided by the summed
        weights, where the more sensitive 4-bit layers weigh more. The resulting factors are clipped to
        [1/8, 8], raised to the relu6 minimum where needed and cast to float32.

        Returns:
            ``(factors, kernel_equalization_sol, min_based_sol, no_equalization_sol)``.
        """
        n_factors = self.number_of_factors
        consumer_shifts, consumer_weights = self._get_one_side_shifts(n_factors, self.consumers)
        producer_shifts, producer_weights = self._get_one_side_shifts(n_factors, self.producers)
        mean_shift = old_div(producer_shifts + consumer_shifts, consumer_weights + producer_weights)
        factors = np.clip(2**mean_shift, 1 / 8, 8)
        factors = np.array(self._set_minimal_factor_for_relu6_set_indices(factors), dtype=np.float32)

        # The min-based data is only needed for the comparison statistics and debug output.
        producers_factors, consumers_factors = self._get_max_based_data()
        min_based_sol, no_equalization_sol = self._print_debug_information(producers_factors, consumers_factors)
        return factors, self.calc_equiv_noise(factors), min_based_sol, no_equalization_sol

    def _get_min_based_factors(self):
        """
        Return the min-based factors, i.e. every channel scaled up as far as the producers' limits allow.

        Returns:
            ``(factors, min_based_sol, min_based_sol, no_equalization_sol)``; the min-based noise is reported both
            as the solution value and as the initial value.
        """
        producer_bounds, consumer_bounds = self._get_max_based_data()
        min_based_sol, no_equalization_sol = self._print_debug_information(producer_bounds, consumer_bounds)
        solution = np.array(producer_bounds.factors_min, dtype=np.float32)
        return solution, min_based_sol, min_based_sol, no_equalization_sol

    def _get_noise_based_factors(self):
        """
        Return the factors found by a greedy, noise-driven heuristic:
            1. start from whichever of the min_based factors and the no-equalization factors (all ones) has the
               lower noise;
            2. go over each channel and try to stretch its factor a bit, continuing while the noise improves.

        If the set has relu6 indices, the search is skipped and the min_based factors are returned.

        Returns:
            ``(factors_solution, sol_value, min_based_sol, no_equalization_sol)`` with ``factors_solution`` as float32.
        """
        producers_factors, consumers_factors = self._get_max_based_data()
        min_based_sol, no_equalization_sol = self._print_debug_information(producers_factors, consumers_factors)

        min_factors = producers_factors.factors_min
        max_factors = producers_factors.factors_max
        default_logger().debug(f"The min_initial_sol_vector limvals [{np.min(min_factors)}, {np.max(min_factors)}]")
        default_logger().debug(f"The max_initial_sol_vector limvals [{np.min(max_factors)}, {np.max(max_factors)}]")
        default_logger().debug(f"The number of factors is: {self.number_of_factors}")

        # Start from the reference solution with the lower noise.
        if min_based_sol > no_equalization_sol:
            default_logger().debug("#################### we are using x0 as default ######################")
            initial_sol_value = no_equalization_sol
            initial_vector = np.ones(self.number_of_factors)
        else:
            initial_sol_value = min_based_sol
            initial_vector = min_factors

        default_logger().debug("######## statistics before noise_based algo (min_based vs no_algo) ##############")
        self._get_algorithm_statistics(min_factors)
        sol_value = initial_sol_value
        sol_vector = copy(initial_vector)

        # relu6 in the set: skip the search and return the min_based factors.
        if len(self._set_indices_relu6) != 0:
            default_logger().debug(
                f"we are using the regular algo because there is relu 6 here {len(self._set_indices_relu6)}",
            )
            factors_solution = np.array(min_factors, dtype=np.float32)
            return factors_solution, min_based_sol, min_based_sol, no_equalization_sol

        max_factor_update = copy(max_factors)
        # Offset by 10 so the first loop check (|delta| > 3) always passes.
        sol_value_beginning = sol_value - 10
        time_start = time.time()
        iteration = 1
        # Copy of the per-producer desired factors filled by _get_producers_factors, so the search cannot alter them.
        desired_producers_factors = deepcopy(self._factors_producers)
        # NOTE(review): `iteration < 2` limits this loop to a single pass; the |delta| > 3 convergence test only
        # matters if that bound is raised.
        while np.abs(sol_value - sol_value_beginning) > 3 and iteration < 2:
            sol_value_beginning = sol_value
            sol_vector, max_factor_update = self._update_factors_by_noise(
                sol_vector,
                desired_producers_factors,
                max_factor_update,
                iteration,
            )
            sol_value = self.calc_equiv_noise(sol_vector)
            iteration += 1

        end_time = time.time()
        factors_solution = np.array(sol_vector, dtype=np.float32)
        default_logger().debug(f"The sol_vector limvals [{np.min(sol_vector)}, {np.max(sol_vector)}]")
        default_logger().debug(f"The max_factor limvals [{np.min(max_factor_update)}, {np.max(max_factor_update)}]")
        delta = sol_value - initial_sol_value
        self.algorithm_information(time_start, end_time, delta, iteration - 1, sol_value)

        return factors_solution, sol_value, min_based_sol, no_equalization_sol

    def _print_debug_information(self, producers_factors, consumers_factors):
        """
        Evaluate the two reference solutions and, outside min_based mode, log the noise of each candidate vector.

        Args:
            producers_factors: producer-side bounds; its ``factors_min`` is the min-based solution.
            consumers_factors: consumer-side bounds; only used for logging.

        Returns:
            ``(min_based_sol, no_equalization_sol)``: the noise of ``producers_factors.factors_min`` and the noise
            of the all-ones (no equalization) factor vector.
        """
        no_equalization_sol = self.calc_equiv_noise(np.ones(self.number_of_factors))
        min_based_sol = self.calc_equiv_noise(producers_factors.factors_min)
        if self.mode != EqualizationMode.min_based:
            for title, factors_container in (("producers", producers_factors), ("consumers", consumers_factors)):
                default_logger().debug(title)
                for attribute_name, vector in vars(factors_container).items():
                    default_logger().debug(f"func on {attribute_name}: {self.calc_equiv_noise(vector):.3f}")
            default_logger().debug(f"func with no equalization {no_equalization_sol}")
        return min_based_sol, no_equalization_sol

    # utils for max based algo
    def _get_max_based_data(self):
        """
        Build the consumer-side and producer-side factor bounds.

        The consumers are computed first: ``_get_consumers_factors`` fills ``self._factors_consumers``, which
        ``_get_producers_factors`` reads.

        Returns:
            ``(producers_factors, consumers_factors)``.
        """
        count = self.number_of_factors
        consumers_factors = self._get_consumers_factors(count)
        return self._get_producers_factors(count), consumers_factors

    def _get_producers_factors(self, number_of_factors):
        """
        Compute the per-index factor bounds suggested by the producer layers.

        For every producer, kernel-, post-activation- and pre-activation-based factors are computed. They take
        into account the following consumers' kernel factors (two-stage mode only) and the concat companions
        (via ``_cross_layers_equalization``). Each vector is padded to ``number_of_factors``, capped by
        ``max_activation_factor`` and stored in ``self._factors_producers[type][equiv_name]``, then reduced across
        producers with nanmin / nanmax.

        Requires ``self._factors_consumers`` to be populated beforehand (by ``_get_consumers_factors``).

        Args:
            number_of_factors: the number of unique indices

        Returns:
            FactorsProducer with the min/max bounds. ``factors_min`` is the element-wise min of the kernel and
            activation minima, with the relu6 minimal factor enforced.
        """
        self._factors_producers = {"kernel": {}, "post_activation": {}, "pre_activation": {}}
        for equiv_layer in self.producers:
            equiv_name = equiv_layer.equiv_name
            # Combined kernel factors of the consumers that follow this producer (None if there are none).
            kernel_factor_consumer = self._get_consumers_factors_of_current_producer(
                equiv_layer,
                self._factors_consumers["kernel"],
            )

            # All ones unless two-stage mode is on and consumer factors exist.
            two_stage_factors = self._get_two_stage_factors(equiv_layer, kernel_factor_consumer)
            kernel_factor = self._get_kernel_max_based_factors(equiv_layer, two_stage_factors=two_stage_factors)
            pre_activation, post_activation = self._get_activation_factors(
                equiv_layer,
                two_stage_factors=two_stage_factors,
            )
            kernel_factor, post_activation, pre_activation = self._cross_layers_equalization(
                equiv_layer,
                kernel_factor,
                post_activation,
                pre_activation,
            )

            self._factors_producers["kernel"][equiv_name] = self._get_padded_layer_factors(
                number_of_factors,
                equiv_layer,
                kernel_factor,
                to_bound=True,
            )

            self._factors_producers["post_activation"][equiv_name] = self._get_padded_layer_factors(
                number_of_factors,
                equiv_layer,
                post_activation,
                to_bound=True,
            )

            self._factors_producers["pre_activation"][equiv_name] = self._get_padded_layer_factors(
                number_of_factors,
                equiv_layer,
                pre_activation,
                to_bound=True,
            )

        # Element-wise minima across producers (NaN = index not covered by a producer).
        kernel_factor_min = self._get_factors_or_default(
            number_of_factors,
            self._factors_producers["kernel"],
            apply_relu=True,
        )
        post_activations_factor_min = self._get_factors_or_default(
            number_of_factors,
            self._factors_producers["post_activation"],
            apply_relu=True,
        )
        pre_activations_factor_min = self._get_factors_or_default(
            number_of_factors,
            self._factors_producers["pre_activation"],
            apply_relu=True,
        )

        # Element-wise maxima across producers.
        # NOTE(review): unlike the other reductions, the kernel max is not clamped for relu6 indices - confirm.
        kernel_factor_max = self._get_factors_or_default(
            number_of_factors,
            self._factors_producers["kernel"],
            callback=np.nanmax,
        )
        post_activation_factor_max = self._get_factors_or_default(
            number_of_factors,
            self._factors_producers["post_activation"],
            callback=np.nanmax,
            apply_relu=True,
        )
        pre_activation_factor_max = self._get_factors_or_default(
            number_of_factors,
            self._factors_producers["pre_activation"],
            callback=np.nanmax,
            apply_relu=True,
        )
        activations_factor_max = np.maximum(post_activation_factor_max, pre_activation_factor_max)

        factors_max = np.maximum(kernel_factor_max, activations_factor_max)

        activations_factor_min = np.minimum(post_activations_factor_min, pre_activations_factor_min)
        factors_min = np.minimum(kernel_factor_min, activations_factor_min)
        factors_min = self._set_minimal_factor_for_relu6_set_indices(factors_min)

        return FactorsProducer(
            factors_min=factors_min,
            factors_max=factors_max,
            kernel_factor_min=kernel_factor_min,
            activations_factor_min=activations_factor_min,
            post_activations_factor_min=post_activations_factor_min,
            pre_activations_factor_min=pre_activations_factor_min,
            kernel_factor_max=kernel_factor_max,
            activations_factor_max=activations_factor_max,
        )

    @staticmethod
    def _get_consumers_factors_of_current_producer(equiv_layer, factors_consumers):
        consumers_of_prod = [
            equiv_layer
            for equiv_layer in equiv_layer.following_consumers
            if equiv_layer.layer.op != LayerType.batch_norm
        ]
        if len(consumers_of_prod) == 0:
            return None
        names = " ,".join([consumer.layer_name for consumer in consumers_of_prod])
        default_logger().debug(f"the consumers of {equiv_layer.layer_name} are {names}")
        final_kernel_c_factor = np.concatenate(
            [factors_consumers[consumer.equiv_name].reshape(1, -1) for consumer in consumers_of_prod],
        )
        return np.nanmin(final_kernel_c_factor, axis=0)

    def _get_consumers_factors(self, number_of_factors):
        """
        Compute the per-index kernel factor bounds suggested by the consumer layers.

        Side effect: stores every consumer's padded kernel factor vector in
        ``self._factors_consumers["kernel"][equiv_name]``, which the producer computation reads.

        Args:
            number_of_factors: length of the padded factor vectors (number of unique set indices).

        Returns:
            FactorsConsumers with the element-wise nanmin / nanmax over the consumers. Without consumers, both
            bounds are all-NaN vectors. The return type is the same in every case, because callers iterate
            ``vars()`` of the result.
        """
        self._factors_consumers = {"kernel": {}}
        if len(self.consumers) == 0:
            return FactorsConsumers(
                kernel_factor_min=np.full(number_of_factors, np.nan),
                kernel_factor_max=np.full(number_of_factors, np.nan),
            )

        for equiv_layer in self.consumers:
            equiv_name = equiv_layer.equiv_name
            kernel_factor = self._get_kernel_max_based_factors(equiv_layer)
            self._factors_consumers["kernel"][equiv_name] = self._get_padded_layer_factors(
                number_of_factors,
                equiv_layer,
                kernel_factor,
            )

        final_kernel_c_factor = np.concatenate(
            [factor.reshape(1, -1) for factor in self._factors_consumers["kernel"].values()],
        )
        kernel_factor_min = np.nanmin(final_kernel_c_factor, axis=0)
        kernel_factor_max = np.nanmax(final_kernel_c_factor, axis=0)

        return FactorsConsumers(kernel_factor_min=kernel_factor_min, kernel_factor_max=kernel_factor_max)

    def _get_activation_factors(self, equiv_layer, two_stage_factors):
        stats_by_layer = self._get_data_for_equiv_layer(equiv_layer)
        indices = np.array(equiv_layer.layer_indices)
        post_activation_max = stats_by_layer["post_activation_max"]
        post_activation_min = stats_by_layer["post_activation_min"]

        max_post_act = np.nanmax(post_activation_max * two_stage_factors)
        min_post_act = np.nanmin(post_activation_min * two_stage_factors)
        # (A) Calculating the "activations-limited" factors vector as layerwise/channelwise ranges ratio,
        #   capped by a constant (normally 16) for robustness.
        post_act_max_factor = old_div(max_post_act, post_activation_max + self.EPSILON)
        post_act_max_factor[post_activation_max <= self.EPSILON] = self.max_activation_factor
        post_act_min_factor = old_div(min_post_act, post_activation_min + self.EPSILON)
        post_act_min_factor[post_activation_min >= -self.EPSILON] = self.max_activation_factor
        post_act_factor = np.minimum(post_act_max_factor, post_act_min_factor)
        post_act_factor = np.minimum(post_act_factor, self.max_activation_factor)

        # (!) (B) Same for "pre-activation-suggested" factors vectors..
        # As elsewhere, pre-activations are checked because the activation statistics may hide large negative
        # pre-activation values, which can cause SC accumulator overflow (..or aggressive preventive action of
        # reducing kernel bits @ quant-time).

        pre_activation_min = stats_by_layer["pre_activation_min"]
        pre_act_factor = old_div(np.abs(max_post_act), np.abs(pre_activation_min) + self.EPSILON)
        pre_act_factor = np.maximum(pre_act_factor, 1.0)

        if not np.all(pre_act_factor > 0):
            raise FactorsCalculatorError(f"pre_act_factor are not positive for {equiv_layer.layer.name}.")
        if not np.all(post_act_factor > 0):
            raise FactorsCalculatorError(f"post_act_facto are not positive for {equiv_layer.layer.name}.")

        return pre_act_factor[indices], post_act_factor[indices]

    def _get_kernel_max_based_factors(self, equiv_layer, two_stage_factors=None):
        """
        Ratio between the layerwise kernel max and each channel's kernel max.

        Producers get ``layer_max / channel_max`` and consumers the inverse ratio (both regularized by EPSILON).
        When ``two_stage_factors`` is given, the layerwise max is the NaN-ignoring max of the channel maxima
        multiplied by those factors.

        Returns:
            The factors restricted to ``equiv_layer.layer_indices``.

        Raises:
            FactorsCalculatorError: if any factor is not positive.
        """
        stats_by_layer = self._get_data_for_equiv_layer(equiv_layer)
        channel_max = self._get_kernels_channels_range(stats_by_layer)
        if two_stage_factors is None:
            layer_max = np.max(channel_max)
        else:
            layer_max = np.nanmax(channel_max * two_stage_factors)

        numerator, denominator = layer_max + self.EPSILON, channel_max + self.EPSILON
        if not equiv_layer.is_producer():
            numerator, denominator = denominator, numerator
        all_factors = old_div(numerator, denominator)

        if not np.all(all_factors > 0):
            raise FactorsCalculatorError(f"kernel factors are not positive for {equiv_layer.layer.name}.")
        return all_factors[np.array(equiv_layer.layer_indices)]

    def _get_two_stage_factors(self, equiv_layer, two_stage_factors):
        # for backwards compatibility
        """Get values for max_kernels_per_input_channel based on all of the equiv-layer consumers"""
        stats_by_layer = self._get_data_for_equiv_layer(equiv_layer)
        padded_max_kernels_per_input_channel = np.ones_like(stats_by_layer["post_activation_max"])
        if two_stage_factors is None or not self.two_stage:
            return padded_max_kernels_per_input_channel
        set_indices = self._equiv_set.get_set_indices_of_equiv_layer(equiv_layer)
        padded_max_kernels_per_input_channel[np.array(equiv_layer.layer_indices)] = two_stage_factors[set_indices]
        padded_max_kernels_per_input_channel[np.isnan(padded_max_kernels_per_input_channel)] = np.ones_like(
            len(np.isnan(padded_max_kernels_per_input_channel)),
        )
        return padded_max_kernels_per_input_channel

    @classmethod
    def _get_kernels_channels_range(cls, stats_by_layer):
        """Get max value per channel"""
        kernel = stats_by_layer["kernel"]
        axes_to_max = stats_by_layer["axes_to_max"]
        axes_to_max = tuple(np.where(axes_to_max)[0])
        return np.max(np.abs(kernel), axis=axes_to_max)

    def _cross_layers_equalization(self, equiv_layer, kernel_factor, post_activation_factor, pre_activation_factor):
        """
        Take into account the other layers this layer is concatenated with - replace simple "layerwise max"
        (the one we want to avoid increasing) by maxing over all "concat companions", because during "scale
        matching" (that happens after Equalization) ranges will anyways all be synced to largest one.
        """
        if self._max_concat is None:
            return kernel_factor, post_activation_factor, pre_activation_factor
        # TODO: a potential problem when the  the concat layers are in the same equiv class

        stats_by_layer = self._get_data_for_equiv_layer(equiv_layer)

        layer_scale = np.minimum(kernel_factor, np.minimum(post_activation_factor, pre_activation_factor))
        post_activation_max = stats_by_layer["post_activation_max"]

        padded_layer_scale = np.ones_like(post_activation_max)
        padded_layer_scale[np.array(equiv_layer.layer_indices)] = layer_scale

        max_value = np.max(post_activation_max)
        # AF: new_max will be exactly to max_val at least for 1-stage?
        new_max = np.max(post_activation_max * padded_layer_scale)

        max_val = np.maximum(max_value, self._max_concat)
        scale_matching = old_div(max_val, (new_max + self.EPSILON))

        kernel_factor *= scale_matching
        post_activation_factor *= scale_matching
        pre_activation_factor *= scale_matching

        if not np.all(kernel_factor > 0):
            raise FactorsCalculatorError(f"kernel factors are not positive for {equiv_layer.layer.name}.")
        if not np.all(post_activation_factor > 0):
            raise FactorsCalculatorError(f"post_activation_factor are not positive for {equiv_layer.layer.name}.")
        if not np.all(pre_activation_factor > 0):
            raise FactorsCalculatorError(f"pre_activation_factor re not positive for  {equiv_layer.layer.name}.")

        return kernel_factor, post_activation_factor, pre_activation_factor

    def _get_factors_or_default(
        self,
        number_of_factors,
        factor_dict,
        default_value=np.nan,
        callback=np.nanmin,
        apply_relu=False,
    ):
        # if the factor does not exist, then return an array with default values
        factors_to_return = np.full(number_of_factors, default_value)
        if factor_dict:
            factors_to_return = callback(
                np.concatenate([factor.reshape(1, -1) for factor in factor_dict.values()]),
                axis=0,
            )
            if apply_relu:
                factors_to_return = self._set_minimal_factor_for_relu6_set_indices(factors_to_return)
        return factors_to_return

    def _get_padded_layer_factors(
        self,
        number_of_factors,
        equiv_layer,
        layer_factor,
        default_value=np.nan,
        callback=np.nanmin,
        to_bound=False,
    ):
        """
        return the factors of the layer - consumer or producer in the length of number_of_factors
        and all the rest are as default value
        """
        layer_factor_padded = np.full(number_of_factors, default_value)
        set_indices = self._equiv_set.get_set_indices_of_equiv_layer(equiv_layer)
        unique_set_indices, layer_factor_unique = self._equiv_set.handle_set_indices_conflicts(
            set_indices,
            layer_factor,
            callback,
        )
        layer_factor_padded[unique_set_indices] = layer_factor_unique
        if to_bound:
            layer_factor_padded = np.minimum(layer_factor_padded, self.max_activation_factor)
        return layer_factor_padded

    def _get_data_for_equiv_layer(self, equiv_layer):
        """Get the relevant stats of the equiv layer"""
        return self._stats_equiv_layers[equiv_layer.type_of_layer][equiv_layer.layer_name]

    def _set_minimal_factor_for_relu6_set_indices(self, factors):
        """
        for all set_indices that have relu6 we want to set a minimal factor
        """
        set_indices_relu6 = self._set_indices_relu6
        if len(set_indices_relu6) != 0:
            factors[set_indices_relu6] = np.maximum(factors[set_indices_relu6], type(self).RELU6_MINIMUM_SCALE)
        return factors

    # utils for based_noise algo
    def _get_new_desired_min_factors_after_update(self, index, update_scale, new_factor, desired_min_factors):
        for producer in self.producers:
            equiv_name = producer.equiv_name
            kernel_factors = desired_min_factors["kernel"][equiv_name]
            factors_post_act = desired_min_factors["post_activation"][equiv_name]
            factors_pre_act = desired_min_factors["pre_activation"][equiv_name]
            act_prod = np.minimum(factors_post_act, factors_pre_act)

            # we now go over all the desired factors and if the new factor will bigger than the old factor
            # the max of this index will be bigger than the old max and thus we should update
            # the factors of all the indices accordingly

            if kernel_factors[index] < new_factor:
                desired_min_factors["kernel"][equiv_name] = np.minimum(
                    kernel_factors * update_scale,
                    self.max_activation_factor,
                )
            if act_prod[index] < new_factor:
                desired_min_factors["post_activation"][equiv_name] = np.minimum(
                    factors_post_act * update_scale,
                    self.max_activation_factor,
                )
                desired_min_factors["pre_activation"][equiv_name] = np.minimum(
                    np.maximum(factors_pre_act * update_scale, 1.0),
                    self.max_activation_factor,
                )

        final_p_factor = np.concatenate(
            [
                factor.reshape(1, -1)
                for current_factors_type in desired_min_factors.values()
                for factor in current_factors_type.values()
            ],
        )

        max_factor = np.nanmax(final_p_factor, axis=0)
        min_factor = np.nanmin(final_p_factor, axis=0)
        return min_factor, max_factor

    def _update_factors_by_noise(
        self,
        factors_solution,
        current_desired_producers_factors,
        maximal_desired_factors,
        iteration_num,
        number_of_rounds=10,
    ):
        """
        Greedily stretch each factor towards its maximal desired value while the equivalence noise improves.

        For every factor index, ``number_of_rounds`` evenly spaced candidates run from the current factor to the
        maximal desired factor; the first candidate is the current factor itself and is skipped. Each candidate
        is propagated to all producers' desired factors via ``_get_new_desired_min_factors_after_update``. It is
        accepted only if it lowers the noise, and the sweep of that index stops at the first candidate that
        does not.

        Args:
            factors_solution: current factor vector.
            current_desired_producers_factors: per-producer desired factors; deep-copied before each candidate,
                so the caller's object is not mutated.
            maximal_desired_factors: per-index upper targets of the sweep.
            iteration_num: outer iteration number (logging only).
            number_of_rounds: number of grid points per index, including the starting point.

        Returns:
            ``(factors_solution, maximal_desired_factors)`` after sweeping all indices.
        """
        initial_sol_value = self.calc_equiv_noise(factors_solution)
        new_algo_sol = initial_sol_value
        # enumerate() runs over factors_to_be_checked[1:], so the last round index is number_of_rounds - 2.
        last_round_number = number_of_rounds - 2
        for factor_index in range(self.number_of_factors):
            # new_algo_sol always equals calc_equiv_noise(factors_solution) (assuming the noise function is
            # deterministic), so reuse it instead of re-evaluating the noise for every channel.
            initial_sol_value_per_channel = new_algo_sol
            current_factor = factors_solution[factor_index]
            maximal_factor = maximal_desired_factors[factor_index]
            factors_to_be_checked = np.linspace(current_factor, maximal_factor, number_of_rounds)
            for round_number, updated_factor in enumerate(factors_to_be_checked[1:]):
                # go over all potential updates of the factor starting from the min factor and ending in the
                # max factor of the current index in fixed steps
                required_scale = old_div(updated_factor, current_factor)
                desired_producers_factors = deepcopy(current_desired_producers_factors)
                new_factors, new_max = self._get_new_desired_min_factors_after_update(
                    factor_index,
                    required_scale,
                    updated_factor,
                    desired_producers_factors,
                )

                new_sol = self.calc_equiv_noise(new_factors)
                if not new_sol < new_algo_sol:
                    if round_number != 0:
                        diff = new_sol - new_algo_sol
                        default_logger().debug(
                            f"stopping in [channel: {factor_index} round: {round_number}, "
                            f"iter: {iteration_num}] due to degradation of {diff} from  "
                            f"{initial_sol_value_per_channel:.3f} ==>> {new_algo_sol:.3f} ",
                        )
                    break
                if round_number == last_round_number:
                    # Reached the maximal desired factor while still improving.
                    default_logger().debug(
                        f"changing in [channel: {factor_index}, round: {round_number}, "
                        f"iter: {iteration_num}] from "
                        f"{initial_sol_value_per_channel:.3f} ==> {new_sol:.3f} ",
                    )
                new_algo_sol = new_sol
                factors_solution = new_factors
                maximal_desired_factors = new_max
                current_factor = updated_factor
                current_desired_producers_factors = desired_producers_factors

        if new_algo_sol != initial_sol_value:
            default_logger().debug(
                f"we changed in iteration {iteration_num} from  {initial_sol_value:.3f} ==>> {new_algo_sol:.3f} ",
            )
        else:
            default_logger().debug(f"no change in iteration {iteration_num} ")
        return factors_solution, maximal_desired_factors

    # all function for kernel_equalization

    def _get_kernel_factors_optimized(self, equiv_layer):
        """
        Ratio between the ``mmse``-optimized layerwise kernel range and the ``mmse``-optimized channel ranges.

        Producers get ``layer_range / (channel_range + EPSILON)`` and consumers ``channel_range / layer_range``.
        """
        layer_range = self._get_optimal_kernel_range(equiv_layer)
        channel_ranges = self._get_optimal_channels_range(equiv_layer)
        if not equiv_layer.is_producer():
            return old_div(channel_ranges, layer_range)
        return old_div(layer_range, (channel_ranges + self.EPSILON))

    def _get_optimal_kernel_range(self, equiv_layer):
        """Return the layerwise kernel range optimized by ``mmse`` for the layer's bit width (offset by 1e-6)."""
        stats_by_layer = self._get_data_for_equiv_layer(equiv_layer)
        eps = 1e-6
        optimal_range = mmse(stats_by_layer["kernel"] + eps, stats_by_layer["number_bits"])
        return optimal_range + eps

    def _get_optimal_channels_range(self, equiv_layer):
        """
        Return the ``mmse``-optimized kernel range (offset by 1e-6) of every channel in ``equiv_layer.layer_indices``.

        The channel axis is the single axis whose ``axes_to_max`` flag is 0.

        Raises:
            FactorsCalculatorError: if ``axes_to_max`` does not flag exactly one channel axis.
        """
        stats_by_layer = self._get_data_for_equiv_layer(equiv_layer)
        eps = 1e-6
        # Loop-invariant: shift the kernel once instead of once per channel.
        shifted_kernel = stats_by_layer["kernel"] + eps
        bits = stats_by_layer["number_bits"]
        axes_to_max = np.array(stats_by_layer["axes_to_max"], dtype=np.int32)

        # int() on a 1-element array is deprecated since NumPy 1.25 and gives an opaque TypeError when zero or
        # several axes qualify; check explicitly instead.
        channel_axes = np.flatnonzero(axes_to_max == 0)
        if channel_axes.size != 1:
            raise FactorsCalculatorError(
                f"expected exactly one channel axis in axes_to_max for {equiv_layer.layer_name}, "
                f"got {channel_axes.size}.",
            )
        axis = int(channel_axes[0])

        return np.array(
            [mmse(np.take(shifted_kernel, index, axis=axis), bits) + eps for index in equiv_layer.layer_indices],
        )

    def _get_one_side_shifts(self, number_of_factors, equiv_layers_list):
        """
        Accumulate the desired log2 shifts of all layers on one side (producers or consumers).

        Returns:
            ``(final_shifts_one_side, final_denominator_one_side)``. The first is the element-wise ``nansum`` of
            the layers' padded shift vectors; the second is computed by ``_get_denominator`` from the per-layer
            skew weights.
        """
        desired_shifts, denominator_by_equiv_layer = {}, {}
        for equiv_layer in equiv_layers_list:
            name = equiv_layer.equiv_name
            layer_shifts, layer_weight = self._get_kernels_desired_shifts(equiv_layer)
            denominator_by_equiv_layer[name] = layer_weight
            desired_shifts[name] = self._get_padded_layer_factors(
                number_of_factors,
                equiv_layer,
                layer_shifts,
                callback=np.nansum,
            )

        summed_shifts = self._get_factors_or_default(number_of_factors, desired_shifts, callback=np.nansum)
        weights = self._get_denominator(number_of_factors, desired_shifts, denominator_by_equiv_layer)
        return summed_shifts, weights

    def _log_2_factors(self, equiv_layer, kernel_factor):
        """
        We sum the logs of the ratios(kernel_factors), giving higher weighting to the more sensitive 4b layers.

        Args:
            equiv_layer:
            kernel_factor:

        Returns:
            desired_shifts: an array in the size of the kernel_factors - for each channel we get the log2 of the factor,
                            clipping high/low values
            skew_factor:  an array in the size of the kernel_factors with "weight" we give to channel

        """
        stats = self._get_data_for_equiv_layer(equiv_layer)
        layer_name = equiv_layer.layer_name
        indices = np.array(equiv_layer.layer_indices)
        desired_shifts = np.log2(kernel_factor)
        # Ignore / moderate too large shifts.

        if equiv_layer.is_producer() and np.any(desired_shifts > 8):
            ignored_indices = np.where(desired_shifts > 8)[0]
            desired_shifts[ignored_indices] = np.nan * len(ignored_indices)
            default_logger().debug(
                f"Seems like a degenerate output channel #{indices[ignored_indices]} in {layer_name}!",
            )
        if equiv_layer.is_consumer() and np.any(desired_shifts < -8):
            ignored_indices = np.where(desired_shifts < -8)[0]
            desired_shifts[ignored_indices] = np.nan * len(ignored_indices)

            default_logger().debug(
                f"Seems like a degenerate input channel #{indices[ignored_indices]} in {layer_name}!",
            )
        desired_shifts = np.clip(desired_shifts, -4, 4)

        skew_factor = 3 if stats["number_bits"] == 4 else 1
        desired_shifts = desired_shifts * skew_factor  # Skew in favor of 4b layers
        return desired_shifts, skew_factor

    def _get_kernels_desired_shifts(self, equiv_layer):
        """
        calculate each of the kernels involved - its individual quantization optimization
        (roughly, layerwise/channelwise ratios) and the logs of the ratios.
        """
        kernel_factor = self._get_kernel_factors_optimized(equiv_layer)
        return self._log_2_factors(equiv_layer, kernel_factor)

    @staticmethod
    def _get_denominator(number_of_factors, factor_dict, denominator_dict):
        factor_dict_copy = copy(factor_dict)
        factor_default = np.full(number_of_factors, np.nan)
        if factor_dict:
            # replace all not nan values in the array with the denominator of the layer.
            for key in factor_dict_copy:
                factor_dict_copy[key][~np.isnan(factor_dict_copy[key][0])] = denominator_dict[key]
            return np.nansum(
                np.concatenate([factor_dict_copy[key].reshape(1, -1) for key in factor_dict]),
                axis=0,
            )
        return factor_default

    @staticmethod
    def _get_geometric_mean(kernel_factor_producer, kernel_factor_consumers, configuration):
        sum_factors = configuration.alpha_ker_producer + configuration.alpha_ker_consumer
        array_1 = np.power(kernel_factor_producer, configuration.alpha_ker_producer / sum_factors).reshape(1, -1)
        array_2 = np.power(kernel_factor_consumers, configuration.alpha_ker_consumer / sum_factors).reshape(1, -1)

        concatenation = np.concatenate([array_1, array_2])
        geometric_mean = np.prod(concatenation, axis=0, keepdims=True)
        # we may get that some of the values are nan for get the values of the not nan elements in the array
        non_nan_values = np.nanmin(concatenation, axis=0, keepdims=True)[np.isnan(geometric_mean)]
        geometric_mean[np.isnan(geometric_mean)] = non_nan_values
        return geometric_mean

    @staticmethod
    def _geometric_mean_nan(array):
        list_of = []
        for equiv_layer_values in array.transpose():
            list_of.append(gmean(equiv_layer_values[~np.isnan(equiv_layer_values)]))
        return np.array(list_of)

    # All functions below are used for the noise calculation.

    def calc_equiv_noise(self, factors_vector):
        # this function is the main optimization function we want to minimize
        # for all the notations you may look at https://hailotech.atlassian.net/wiki/spaces/ML/pages/edit-v2/826081518
        # which is a modified version of the calc noise in https://arxiv.org/pdf/1902.01917.pdf
        inverse_sqnr = 0
        assert len(factors_vector) == self.number_of_factors
        for consumer in self.consumers:
            noise_cons = self._calc_per_layer_consumer(consumer, factors_vector)
            inverse_sqnr += noise_cons

        return inverse_sqnr

    def y1_y1_noise(self, producer, factors_vector):
        """
        Activation-quantization noise at the producer output y1 once ``factors_vector`` is applied.

        The post-activation min/max statistics are rescaled per channel by the padded factors. The noise of an 8-bit
        grid over the resulting range is returned. For ReLU producers (when ``RELU_SENSITIVE`` is set) the noise is
        multiplied by the layer's ``non_zero_percent`` statistic.
        """
        stats_by_layer = self._get_data_for_equiv_layer(producer)

        axes_to_max = np.array(stats_by_layer["axes_to_max"], dtype=np.int32)
        # The single zero entry marks the channel axis. ``.item()`` avoids the int(<1-element 1-D array>)
        # conversion that NumPy >= 1.25 deprecates.
        axis = np.flatnonzero(axes_to_max == 0).item()

        padded_factors = self._get_padded_factor(producer, factors_vector, stats_by_layer["kernel"].shape, axis)
        max_layer = np.max(stats_by_layer["post_activation_max"] * padded_factors)
        min_layer = np.min(stats_by_layer["post_activation_min"] * padded_factors)
        range_of = max_layer - min_layer

        activation_noise = self._calculate_noise_by_range(range_of, 8)
        if producer.layer.activation == ActivationType.relu and self.RELU_SENSITIVE:
            activation_noise *= stats_by_layer["non_zero_percent"]

        return activation_noise

    def y1_w1_noise(self, producer, factors_vector):
        to_compute, w1, x1_energy, bits, _ = self._get_kernel_and_input_info_by_factors(producer, factors_vector)
        mean_x1_energy = np.mean(x1_energy)
        noise_weight_w1 = self._calculate_noise_weights(w1, bits)
        return to_compute * mean_x1_energy * noise_weight_w1

    def y_2_w_1_noise(self, producer, w2_sum, factors_vector, consumer_output_features):
        y1_w1_noise = self.y1_w1_noise(producer, factors_vector)
        return (w2_sum * y1_w1_noise) / consumer_output_features

    def y_2_y_1_noise(self, producer, w2_sum, factors_vector, consumer_output_features):
        noise_activation_y1 = self.y1_y1_noise(producer, factors_vector)
        return (w2_sum * noise_activation_y1) / consumer_output_features

    def y_2_w_2_noise(self, to_compute, w2, x2_energy, bits):
        mean_x2_energy = np.mean(x2_energy)
        noise_weight_w2 = self._calculate_noise_weights(w2, bits)

        return to_compute * mean_x2_energy * noise_weight_w2

    @staticmethod
    def _kernel_sum(kernel):
        return np.sum(kernel**2)

    @staticmethod
    def _get_padded_factor(equiv_layer, factors, shape_kernel, axis):
        indices = np.array(equiv_layer.layer_indices)
        padded_factors = np.ones(shape_kernel[axis])
        padded_factors[indices] = factors[equiv_layer.set_indices]
        return padded_factors

    def _get_algorithm_statistics(self, factors_solution, compared_solution=None):
        """
        Log (debug level) per-layer signal, noise and SQNR for ``factors_solution`` next to ``compared_solution``.

        Every logged pair reads ``[value with factors_solution, value with compared_solution]``.

        Args:
            factors_solution: factors vector to report on.
            compared_solution: reference factors vector; defaults to all ones (no rescaling).
        """
        reference = np.ones(len(factors_solution)) if compared_solution is None else compared_solution

        for producer in self.producers:
            (
                w1_noise,
                y1_noise,
                signal,
                noise,
                sqnr,
                sqnr_w1,
                sqnr_y1,
            ) = self._wrapper_for_calc_layer_prod(producer, factors_solution)
            (
                w1_noise_ref,
                y1_noise_ref,
                signal_ref,
                noise_ref,
                sqnr_ref,
                sqnr_w1_ref,
                sqnr_y1_ref,
            ) = self._wrapper_for_calc_layer_prod(producer, reference)

            name = producer.layer_name
            default_logger().debug(f"for producer: {name} the is signal [{signal} , {signal_ref}]")
            default_logger().debug(
                f"for producer: {name} snr [{sqnr:.2f}, {sqnr_ref:.2f}] noise  [{noise} , {noise_ref}]",
            )
            default_logger().debug(
                f"for producer: {name} snr_y1_w1 [{sqnr_w1:.2f}, {sqnr_w1_ref:.2f}] "
                f"noise_y1_w1 [{w1_noise} , {w1_noise_ref}]",
            )
            default_logger().debug(
                f"for producer: {name} snr_y1_y1 [{sqnr_y1:.2f}, {sqnr_y1_ref:.2f}] "
                f"noise_y1_y1 [{y1_noise} , {y1_noise_ref}]",
            )

        for consumer in self.consumers:
            noise, signal, per_producer, y2_w2, sqnr = self._calc_data_layer_consumer_for_statistics(
                consumer,
                factors_solution,
            )
            noise_ref, signal_ref, per_producer_ref, y2_w2_ref, sqnr_ref = (
                self._calc_data_layer_consumer_for_statistics(consumer, reference)
            )

            name = consumer.layer_name
            default_logger().debug(f"for consumer: {name} the is signal [{signal} , {signal_ref}]")
            default_logger().debug(
                f"for consumer: {name} snr [{sqnr:.2f}, {sqnr_ref:.2f}] noise  [{noise}, {noise_ref}]",
            )
            default_logger().debug(f"on: {name} y2_w2_noise  [{y2_w2:.2f}, {y2_w2_ref:.2f}]")

            for producer_name, terms in per_producer.items():
                for term_name, value in terms.items():
                    default_logger().debug(
                        f"for prod {producer_name} {term_name} is [{value:.2f}, "
                        f"{per_producer_ref[producer_name][term_name]:.2f}]",
                    )

    def _wrapper_for_calc_layer_prod(self, producer, factors):
        y1_w1_noise, y1_y1_noise, signal_prod = self._calc_per_layer_producer(producer, factors)
        noise_prod = y1_w1_noise + y1_y1_noise
        sqnr_prod = 10 * np.log10(signal_prod / noise_prod)
        sqnr_prod_y1_w1 = 10 * np.log10(signal_prod / y1_w1_noise)
        sqnr_prod_y1_y1 = 10 * np.log10(signal_prod / y1_y1_noise)

        return y1_w1_noise, y1_y1_noise, signal_prod, noise_prod, sqnr_prod, sqnr_prod_y1_w1, sqnr_prod_y1_y1

    @staticmethod
    def _calculate_noise_by_range(range_of, bits):
        up_c = (range_of / (2**bits)) ** 2
        return up_c / 12

    def _calculate_noise_weights(self, w, bits):
        max_layer = np.max(np.abs(w))
        range_of = 2 * max_layer
        return self._calculate_noise_by_range(range_of, bits)

    def _get_kernel_and_input_info_by_factors(self, equiv_layer, factors_vector):
        """
        Apply the equalization factors to a layer's kernel and, for a consumer, to its input energy.

        A consumer's kernel channels are scaled by ``1 / factors_vector`` and its input energy by
        ``factors_vector ** 2``. A producer's kernel channels are scaled by ``factors_vector``. The consumer branch is
        checked first. Channels not listed in ``equiv_layer.layer_indices`` keep a factor of 1.

        Args:
            equiv_layer: producer or consumer layer of the equivalence set.
            factors_vector: numpy array with one factor per set index.

        Returns:
            tuple ``(pixel_batch_scale, w, x, bits, output_features)``:
                pixel_batch_scale: number of kernel elements divided by the number of output features, truncated to
                    int. For grouped conv/deconv layers it is taken from ``kernel_before``.
                w: the scaled kernel, a new array (the stored kernel is not modified).
                x: a copy of the input energy, scaled for consumers.
                bits: ``stats["number_bits"]``.
                output_features: ``equiv_layer.layer.output_features``.

        Raises:
            FactorsCalculatorError: if ``equiv_layer`` is neither a producer nor a consumer.
        """
        stats_by_layer = self._get_data_for_equiv_layer(equiv_layer)
        layer = equiv_layer.layer
        kernel = stats_by_layer["kernel"]
        # Copy: the consumer branch below scales the input energy in place.
        x = deepcopy(stats_by_layer["input_energy"])
        bits = stats_by_layer["number_bits"]
        axes_mask = np.array(stats_by_layer["axes_to_max"], dtype=np.int32)
        # The single zero entry marks the channel axis. ``.item()`` avoids the int(<1-element 1-D array>)
        # conversion that NumPy >= 1.25 deprecates.
        axis = np.flatnonzero(axes_mask == 0).item()
        # Reshape target for the per-channel factors: -1 on the channel axis.
        axes_mask[axes_mask == 0] = -1
        output_features = layer.output_features
        pixel_batch_scale = int(np.array(kernel.shape).prod() / output_features)
        if layer.op in [LayerType.conv, LayerType.deconv] and layer.groups > 1:
            kernel_before = stats_by_layer["kernel_before"]
            pixel_batch_scale = int(np.array(kernel_before.shape).prod() / output_features)

        if equiv_layer.is_consumer():
            factors_to_use = 1 / factors_vector
        elif equiv_layer.is_producer():
            factors_to_use = factors_vector
        else:
            raise FactorsCalculatorError(
                f"The noise calculation expects a producer or "
                f"consumer but {equiv_layer.layer.name} is of type None.",
            )
        padded_factor = self._get_padded_factor(equiv_layer, factors_to_use, kernel.shape, axis)
        kernels_factors = np.reshape(padded_factor, axes_mask)
        if equiv_layer.is_consumer():
            padded_factor_for_input = self._get_padded_factor(equiv_layer, factors_vector, kernel.shape, axis)
            x *= padded_factor_for_input**2

        # The product allocates a new array, so the stored kernel needs no defensive deepcopy.
        w = kernel * kernels_factors
        return pixel_batch_scale, w, x, bits, output_features

    def _calc_noise_of_specific_prod_cons(self, producer, w2_sum, factors_vector, output_features_consumers):
        y2_w_1_noise = self.y_2_w_1_noise(producer, w2_sum, factors_vector, output_features_consumers)
        y_2_y_1_noise = self.y_2_y_1_noise(producer, w2_sum, factors_vector, output_features_consumers)
        return y2_w_1_noise + y_2_y_1_noise

    def _get_output_energy(self, equiv_layer, factors_vector=None):
        stats_by_layer = self._get_data_for_equiv_layer(equiv_layer)
        output = deepcopy(stats_by_layer["output_energy"])
        w = stats_by_layer["kernel"]
        axes_to_max = np.array(stats_by_layer["axes_to_max"], dtype=np.int32)
        axis = int(np.where(axes_to_max == 0)[0])
        if equiv_layer.is_producer():
            padded_factor_output = self._get_padded_factor(equiv_layer, factors_vector, w.shape, axis)
            output *= padded_factor_output**2
        return output

    def _calc_per_layer_consumer(self, consumer, factors_vector):
        noise_cons, signal_cons = self._calc_data_layer_consumer(consumer, factors_vector)
        if signal_cons == 0:
            return 0
        return noise_cons / signal_cons

    def _calc_data_layer_consumer(self, consumer, factors_vector):
        to_compute, w2, x2_energy, bits, output_features_consumers = self._get_kernel_and_input_info_by_factors(
            consumer,
            factors_vector,
        )
        w2_sum = self._kernel_sum(w2)
        noise_cons = self.y_2_w_2_noise(to_compute, w2, x2_energy, bits)
        for producer in consumer.prev_producers:
            noise_cons += self._calc_noise_of_specific_prod_cons(
                producer,
                w2_sum,
                factors_vector,
                output_features_consumers,
            )
        y2_energy = self._get_output_energy(consumer)
        signal_cons = np.mean(y2_energy)
        return noise_cons, signal_cons

    def _calc_data_layer_consumer_for_statistics(self, consumer, factors_vector):
        dict_prod = {}
        noise_cons = 0
        to_compute, w2, x2_energy, bits, consumer_feature_out = self._get_kernel_and_input_info_by_factors(
            consumer,
            factors_vector,
        )
        w2_sum = self._kernel_sum(w2)
        y2_w2_noise = self.y_2_w_2_noise(to_compute, w2, x2_energy, bits)
        noise_cons += y2_w2_noise
        for producer in consumer.prev_producers:
            dict_prod[producer.layer_name] = {}
            y2_w1_noise = self.y_2_w_1_noise(producer, w2_sum, factors_vector, consumer_feature_out)
            y2_y1_noise = self.y_2_y_1_noise(producer, w2_sum, factors_vector, consumer_feature_out)
            noise_producer_cons = y2_w1_noise + y2_y1_noise
            dict_prod[producer.layer_name]["y2_w1_noise"] = y2_w1_noise
            dict_prod[producer.layer_name]["y2_y1_noise"] = y2_y1_noise
            dict_prod[producer.layer_name]["the_sum"] = noise_producer_cons
            noise_cons += noise_producer_cons
        y2_energy = self._get_output_energy(consumer)
        signal_cons = np.mean(y2_energy)
        sqnr_cons = 10 * np.log10(signal_cons / noise_cons)
        return noise_cons, signal_cons, dict_prod, y2_w2_noise, sqnr_cons

    def _calc_per_layer_producer(self, producer, factors_vector):
        assert len(factors_vector) == self.number_of_factors
        y1_w1_noise = self.y1_w1_noise(producer, factors_vector)
        y1_y1_noise = self.y1_y1_noise(producer, factors_vector)
        y1_energy = self._get_output_energy(producer, factors_vector)
        energy_mean = np.mean(y1_energy)
        return y1_w1_noise, y1_y1_noise, energy_mean
