from collections import namedtuple
from functools import reduce

import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.hailo_layer_norm import HailoLayerNorm
from hailo_model_optimization.acceleras.hailo_layers.hailo_layer_norm_mercury import HailoLayerNormMercury
from hailo_model_optimization.acceleras.hailo_layers.hailo_layer_normalization import HailoLayerNormalization
from hailo_model_optimization.acceleras.hailo_layers.hailo_precision_split import HailoPrecisionSplit
from hailo_model_optimization.acceleras.hailo_layers.op_factories import gen_acceleras_layers_from_hn
from hailo_model_optimization.acceleras.lossy_elements.quant_element import APUOutputQuantElement
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_base import CommandMeta
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerAdaRoundConfig,
    LayerBiasCorrectionConfig,
    LayerEqualizationConfig,
    LayerNegExponentConfig,
    LayerPrecisionConfig,
    LayerTranslationConfig,
    LayerZeroStaticChannelsConfig,
)
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    LayerNormMode,
    OptimizationTarget,
    PrecisionMode,
    ThreeWayPolicy,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasImplementationError
from hailo_model_optimization.algorithms.mixed_precision.create_mixed_precision import CreateMixedPrecision
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm
from hailo_model_optimization.algorithms.stats_collection.stats_collection import StatsCollector

# Layer classes this algorithm decomposes; `norm_layers` matches on the exact type.
SUPPORTED_LAYERS = [HailoLayerNormalization]
# When True, `_setup` logs the resolved configuration values.
DEBUG = False
# Names of the layers that take part in the equalization of one decomposed layer norm.
EquivClassNorm = namedtuple("EquivClassNorm", ["source", "consumer_square", "consumer_out"])
# Names of the exp-decompose / shift layer pair created by token equalization.
MaskExpDecomposeClassNorm = namedtuple("MaskExpDecomposeClassNorm", ["exp_decompose", "shift"])


class DecomposeLayerNorm(OptimizationAlgorithm):
    def __init__(
        self,
        model: HailoModel,
        model_config,
        logger_level,
        dataset,
        **kwargs,
    ):
        """Create the layer-norm decomposition algorithm.

        Args:
            model: model whose layer-normalization layers are decomposed.
            model_config: model optimization configuration, forwarded to the base class.
            logger_level: logging level, forwarded to the base class.
            dataset: dataset kept for the statistics collection that may run
                after the decomposition.
            **kwargs: forwarded unchanged to the base class constructor.
        """
        super().__init__(
            model,
            model_config,
            logger_level=logger_level,
            name="LayerNorm Decomposition",
            **kwargs,
        )
        # Logger level handed to the nested algorithms this class launches.
        self._logger_level_other = 9
        self._unbatched_dataset = dataset

    def _setup(self):
        """Load the algorithm configuration into instance state and validate it.

        Reads the configuration returned by ``get_algo_config``, resets the
        per-run bookkeeping containers and then calls ``_verify``.
        """
        config = self.get_algo_config()
        self._mode = config.mode
        self._equalization = config.equalization == ThreeWayPolicy.enabled
        self._nudging = config.nudging
        self._group_nudging = config.group_nudging
        self._square_12_bit = config.square_12_bit
        # Gate on the resolved boolean, not on the raw policy value: a policy
        # member is truthy even when it means "disabled", so using it directly
        # would reduce this expression to `self._square_12_bit` alone.
        self._optimize_ew_mult = self._equalization and self._square_12_bit
        self._token_equalization = config.token_equalization == ThreeWayPolicy.enabled
        # Bookkeeping filled while decomposing and consumed afterwards by `_run_int_core`.
        self._equalization_info_by_layer = dict()
        self._mask_exp_decompose_info_by_layer = dict()
        self._precision_split_layers = set()
        self._verify()
        if DEBUG:
            self._logger.info(f"self.optimization_target {self.optimization_target}")
            self._logger.info(f"self._mode {self._mode}")
            self._logger.info(f"self._equalization  {self._equalization}")
            self._logger.info(f"self._nudging {self._nudging}")
            self._logger.info(f"self._group_nudging {self._group_nudging} ")
            self._logger.info(f"self._square_12_bit {self._square_12_bit}")
            self._logger.info(f"self._optimize_ew_mult {self._optimize_ew_mult}")
            self._logger.info(f"self._token_equalization {self._token_equalization}")

    def _verify(self):
        """Reject configurations the decomposition cannot implement.

        Raises:
            AccelerasImplementationError: when PPU mode is requested for the
                SAGE target, or when PPU mode is requested and not every layer
                norm reports a 16 bit input precision mode.
        """
        ppu_requested = self._mode == LayerNormMode.ppu
        if ppu_requested and self.optimization_target == OptimizationTarget.SAGE:
            raise AccelerasImplementationError("SAGE does not support PPU mode")
        inputs_are_16_bit = [self._is_16_bit_inp(norm_layer) for norm_layer in self.norm_layers]
        if ppu_requested and not np.all(inputs_are_16_bit):
            raise AccelerasImplementationError("PPU mode must have all layers as 16 bit input")

    def should_skip_algo(self):
        return len(self.norm_layers) == 0

    def get_algo_config(self):
        return self._model_config.layer_norm_decomposition

    def _run_int(self):
        """Dispatch to the decomposition flow matching the target and mode.

        The core flow is used for the SAGE target and whenever the mode is
        ``nn_core``; otherwise MERCURY and PLUTO each use their own flow. Any
        other target is left untouched.
        """
        target = self.optimization_target
        if self._mode == LayerNormMode.nn_core or target == OptimizationTarget.SAGE:
            self._run_int_core()
            return
        if target == OptimizationTarget.MERCURY:
            self._run_int_mercury()
            return
        if target == OptimizationTarget.PLUTO:
            self._run_int_pluto()

    def _run_int_mercury(self):
        for layer in self.norm_layers:
            self._decompose_mercury(layer)

    def _run_int_pluto(self):
        """Apply the PLUTO decomposition to every layer norm, then run CreateMixedPrecision."""
        pending_layers = self.norm_layers
        for norm_layer in pending_layers:
            self._decompose_pluto(norm_layer)

        mixed_precision = CreateMixedPrecision(
            model=self._model,
            model_config=self._model_config,
            logger_level=self._logger_level_other,
            logger=self._logger,
        )
        mixed_precision.run()

    def _run_int_core(self):
        """Run the core decomposition flow.

        Decomposes every layer norm, runs CreateMixedPrecision, and - only when
        the decomposition registered layers that need statistics - collects
        statistics for them and then applies equalization, precision splitting
        and exp-decompose masking, in that order.
        """
        # Decompose all layer norms before anything else touches the model.
        pending_layers = self.norm_layers
        for norm_layer in pending_layers:
            self.decompose_single_layer_norm(norm_layer)

        mixed_precision = CreateMixedPrecision(
            model=self._model,
            model_config=self._model_config,
            logger_level=self._logger_level_other,
            logger=self._logger,
        )
        mixed_precision.run()

        if not self._run_statistics:
            return

        StatsCollector(
            self._model,
            self._model_config,
            self._logger_level_other,
            self._unbatched_dataset,
            logger=self._logger,
            layers_to_handle=self._layers_to_collect,
        ).run()

        for split_name in self._equalization_info_by_layer:
            self.equalize_layer(split_name)
        for split_name in self._precision_split_layers:
            self._split_layer(self._model, split_name)
        for exp_name in self._mask_exp_decompose_info_by_layer:
            self.mask_exp_decompose(exp_name)

    @property
    def _run_statistics(self):
        return (
            len(self._equalization_info_by_layer)
            + len(self._precision_split_layers)
            + len(self._mask_exp_decompose_info_by_layer)
        ) > 0

    @property
    def norm_layers(self):
        """List the model layers, in topological order, whose exact type is supported.

        The list is rebuilt from the model on every access.
        """
        layers_in_order = (self._model.layers[lname] for lname in self._model.flow.toposort())
        # Exact type match on purpose: subclasses of a supported class are not selected.
        return [layer for layer in layers_in_order if type(layer) in SUPPORTED_LAYERS]

    @property
    def _layers_to_collect(self):
        layers_to_collect = set()
        for lname, equiv_set in self._equalization_info_by_layer.items():
            layers_to_collect.add(lname)
            layers_to_collect.add(equiv_set.source)
            layers_to_collect.add(equiv_set.consumer_square)
            layers_to_collect.add(equiv_set.consumer_out)
        for lname in self._precision_split_layers:
            layers_to_collect.add(lname)
        for lname in self._mask_exp_decompose_info_by_layer:
            layers_to_collect.add(lname)
        return layers_to_collect

    def finalize_global_cfg(self, algo_config):
        """Promote an ``allowed`` equalization policy to ``enabled`` in place."""
        policy_is_allowed = algo_config.equalization == ThreeWayPolicy.allowed
        if policy_is_allowed:
            algo_config.equalization = ThreeWayPolicy.enabled

    def _get_valid_layer_cfg(self, lname, cfg):
        return cfg

    @staticmethod
    def _is_16_bit_inp(layer_norm):
        """Return True when the layer's precision mode is one of the a16_w16* modes."""
        sixteen_bit_input_modes = (
            PrecisionMode.a16_w16_a16,
            PrecisionMode.a16_w16_a8,
            PrecisionMode.a16_w16,
        )
        return layer_norm.get_precision_mode() in sixteen_bit_input_modes

    def decompose_single_layer_norm(self, layer_norm):
        """
        decompose single layer norm by its precision mode and other parameters
        """
        # this function will be called when we enable equalization of 8 bit
        if self._is_16_bit_inp(layer_norm):
            self.decompose_layer_norm_16bit_input(layer_norm)
        else:
            # TODO enable equalization for 8 bit input (decompose_layer_norm_8bit_input)
            self._decompose_layer_norm_8bit_input_no_equalization(layer_norm)

    def _remove_equalization_on_input_layer_norm(self, layer_norm):
        """Disable equalization on the first predecessor of ``layer_norm``."""
        # TODO: iterate backwards to find the source layers before the layer norm
        predecessors = self._model.flow.predecessors_sorted(layer_norm.full_name)
        first_predecessor = predecessors[0]
        disabled_policy = LayerEqualizationConfig(policy="disabled")
        self._model_config.equalization.layers[first_predecessor] = disabled_policy

    def create_mean_resize_ew_sub(self, layer_norm, block_name, layer_norm_name, layer_norm_shape, precision_mode):
        """
        Build the mean-subtraction part of the decomposition and return its last layer.

        The predecessor of ``layer_norm`` is fed into a mean layer, the result is
        resized back using ``layer_norm_shape`` and subtracted element-wise from
        the predecessor's output. When ``layer_norm.groups > 1`` an extra
        spatial mean / spatial resize pair is inserted between the mean and the
        final resize.

        Args:
            layer_norm: the layer norm being decomposed.
            block_name: prefix used for every created layer name.
            layer_norm_name: layer name used as the suffix of every created layer name.
            layer_norm_shape: shape handed to the resize layers.
            precision_mode: precision mode handed to every created layer.

        Returns:
            The element-wise subtraction layer.
        """
        bias_mode = "single_scale_decomposition"
        groups = layer_norm.groups
        reduce_mean1_name = f"{block_name}reduce_mean1_{layer_norm_name}"
        ew_sub_all_name = f"{block_name}ew_sub_all_{layer_norm_name}"
        resize1_name = f"{block_name}resize1_{layer_norm_name}"
        reduce_mean_spatial_name = f"{block_name}reduce_mean_spatial_{layer_norm_name}"

        # The predecessor now feeds both the mean layer and the subtraction layer.
        output_names = [reduce_mean1_name, ew_sub_all_name]
        pred_layer_norm = self.update_pred_layer(layer_norm, output_names)

        bias_zero = np.array(0)

        # The first mean feeds the final resize directly unless a spatial mean is needed.
        output_reduce_mean1 = resize1_name if groups == 1 else reduce_mean_spatial_name
        if groups > 1:
            reduce_mean1 = self.add_reduce_mean_layer(
                reduce_mean1_name,
                [output_reduce_mean1],
                [pred_layer_norm],
                bias_zero,
                "linear",
                groups=groups,
                bias_mode=bias_mode,
                precision_mode=precision_mode,
            )
        else:
            reduce_mean1 = self.add_conv_mean_layer(
                reduce_mean1_name,
                [output_reduce_mean1],
                [pred_layer_norm],
                bias_zero,
                "linear",
                factors=(1,),
                bias_mode=bias_mode,
                precision_mode=precision_mode,
            )
        prev_resize1 = reduce_mean1
        if groups > 1:
            # Grouped case: add a spatial mean and a spatial (non-channel) resize.
            resize_spatial_name = f"{block_name}resize_spatial_{layer_norm_name}"
            reduce_mean_spatial = self.add_spatial_reduce_mean_layer(
                reduce_mean_spatial_name,
                [resize_spatial_name],
                [reduce_mean1],
                bias_zero,
                "linear",
                bias_mode=bias_mode,
                precision_mode=precision_mode,
            )

            resize_spatial = self.add_resize_layer(
                resize_spatial_name,
                [resize1_name],
                [reduce_mean_spatial],
                layer_norm_shape,
                bias_mode=bias_mode,
                precision_mode=precision_mode,
                channels=False,
            )
            prev_resize1 = resize_spatial

        resize1 = self.add_resize_layer(
            resize1_name,
            [ew_sub_all_name],
            [prev_resize1],
            layer_norm_shape,
            bias_mode=bias_mode,
            precision_mode=precision_mode,
        )

        # NOTE(review): the declared output is the precision-split layer, while the
        # visible caller routes this layer into the equalization source - confirm
        # whether the declared output name matters here.
        ew_sub_all = self.add_ew_sub_layer(
            ew_sub_all_name,
            [f"{block_name}precision_split_{layer_norm_name}"],
            [pred_layer_norm, resize1],
            bias_mode=bias_mode,
            precision_mode=precision_mode,
        )
        return ew_sub_all

    def create_token_equalization(self, layer_norm, pred_layer, block_name, layer_norm_name, output_names):
        """
        Insert the token-equalization sub-graph after ``pred_layer``.

        Creates an ``exp_decompose`` activation on ``pred_layer``, reduces it
        with one max layer (two when ``layer_norm.groups > 1``) and feeds
        ``pred_layer`` together with the reduced result into a shift layer. The
        exp-decompose / shift pair is registered in
        ``_mask_exp_decompose_info_by_layer``.

        Args:
            layer_norm: the layer norm being decomposed; only ``groups`` is read.
            pred_layer: layer whose output is equalized.
            block_name: prefix used for every created layer name.
            layer_norm_name: layer name used as the suffix of every created layer name.
            output_names: output names declared on the shift layer.

        Returns:
            The shift layer.
        """
        groups = layer_norm.groups

        exp_decompose_name = f"{block_name}exp_decompose_{layer_norm_name}"
        reduce_max_name = f"{block_name}reduce_max1_{layer_norm_name}"
        # Without groups there is no spatial max, so the activation feeds reduce_max1 directly.
        reduce_max_spatial_name = f"{block_name}reduce_max_spatial_{layer_norm_name}" if groups > 1 else reduce_max_name
        shift_name = f"{block_name}shift_{layer_norm_name}"

        exp_decompose = self.add_standalone_activation(
            exp_decompose_name,
            [reduce_max_spatial_name],
            [pred_layer],
            activation="exp_decompose",
            precision_mode="a16_w16_a16",
        )

        if groups > 1:
            # Extra max over axes [1, 2] before the max over axis [3].
            reduce_max_input = self.add_reduce_max(
                reduce_max_spatial_name,
                [reduce_max_name],
                [exp_decompose],
                [1, 2],
                groups=1,
                precision_mode="a16_w16_a16",
            )
        else:
            reduce_max_input = exp_decompose
        reduce_max = self.add_reduce_max(
            reduce_max_name, [shift_name], [reduce_max_input], [3], groups=groups, precision_mode="a16_w16_a16"
        )

        shift = self.add_shift_layer(shift_name, output_names, [pred_layer, reduce_max], precision_mode="a16_w16_a16")

        # Registered so that `_run_int_core` collects statistics and masks this pair later.
        self._mask_exp_decompose_info_by_layer[exp_decompose.full_name] = MaskExpDecomposeClassNorm(
            exp_decompose.full_name, shift.full_name
        )

        return shift

    def decompose_layer_norm_16bit_input(self, layer_norm):
        """
        Replace a 16 bit input layer norm with an explicit sub-graph.

        The (optionally mean-subtracted, optionally token-equalized) input goes
        through an ``equalization_source`` normalization layer and a precision
        split layer with a low and a high output (the original note states
        8 and 7 bits respectively - LSB and MSB). The squares of both parts and
        their product are concatenated and reduced into the inverse-sqrt term,
        which is resized and multiplied with both parts; the two products are
        added and passed through an ``equalization_consumer_out`` normalization
        layer that takes over the outputs of ``layer_norm``.

        The precision split layer is registered either for equalization or for
        plain precision splitting, and ``layer_norm`` is removed at the end.
        """
        layer_norm_name = layer_norm.full_name.split("/")[1]
        # splits block name and layer name
        block_name, layer_norm_name = self.get_block_and_layer_names(layer_norm_name)
        layer_norm_shape = layer_norm.to_hn()["input_shapes"][0]  # the output_shape of the layer norm
        precision_mode = "a16_w16"
        groups = layer_norm.groups
        if not layer_norm.rms_norm:
            # Regular layer norm: subtract the mean first.
            pred_equalization_source = self.create_mean_resize_ew_sub(
                layer_norm, block_name, layer_norm_name, layer_norm_shape, precision_mode
            )
            resize_name = f"{block_name}resize2_{layer_norm_name}"
        else:
            # RMS norm: no mean subtraction, the predecessor feeds the source layer directly.
            self._remove_equalization_on_input_layer_norm(layer_norm)
            pred_equalization_source = self.update_pred_layer(
                layer_norm, [f"{block_name}equalization_source_{layer_norm_name}"]
            )
            resize_name = f"{block_name}resize1_{layer_norm_name}"

        if self._token_equalization:
            pred_equalization_source = self.create_token_equalization(
                layer_norm,
                pred_equalization_source,
                block_name,
                layer_norm_name,
                [f"{block_name}equalization_source_{layer_norm_name}"],
            )

        equalization_source = self.add_normalization_layer(
            f"{block_name}equalization_source_{layer_norm_name}",
            [f"{block_name}precision_split_{layer_norm_name}"],
            [pred_equalization_source],
        )

        splitter_layer = self._add_precision_split_layer(
            f"{block_name}precision_split_{layer_norm_name}",
            [f"{block_name}shortcut1_low_{layer_norm_name}", f"{block_name}shortcut2_high_{layer_norm_name}"],
            [equalization_source],
        )
        if self._square_12_bit:
            # The same 12 bit lossy element is shared by the source output and the splitter input.
            lossy_12_bit = APUOutputQuantElement(bits=12)
            equalization_source.output_lossy_element_external = lossy_12_bit
            splitter_layer.input_lossy_element_external = lossy_12_bit

        # Mock kernel values handed to the square / multiply layers created below.
        if self._optimize_ew_mult:
            mock_kernel_values_square_low = [2, 2]
            mock_kernel_values_square_high = [8, 8]
            mock_kernel_values_low_high_mult = [2, 8]
            mock_kernel_values_mult_low_var = [2, 2]
            mock_kernel_values_mult_high_var = [2, 32]
        else:
            mock_kernel_values_square_low = [2, 2]
            mock_kernel_values_square_high = [2, 2]
            mock_kernel_values_low_high_mult = [2, 2]
            mock_kernel_values_mult_low_var = [2, 2]
            mock_kernel_values_mult_high_var = [2, 2]

        # Output 0 of the splitter is the low part, output 1 the high part.
        x_low = self.add_shortcut_layer(
            f"{block_name}shortcut1_low_{layer_norm_name}",
            [f"{block_name}square_low_{layer_norm_name}", f"{block_name}ew_mult_low_high_{layer_norm_name}"],
            [splitter_layer],
            output_index=0,
        )
        x_high = self.add_shortcut_layer(
            f"{block_name}shortcut2_high_{layer_norm_name}",
            [f"{block_name}square_high_{layer_norm_name}", f"{block_name}ew_mult_low_high_{layer_norm_name}"],
            [splitter_layer],
            output_index=1,
        )

        # low^2, high^2 and low*high, each followed by a precision change before the concat.
        square_low = self.add_square_layer(
            f"{block_name}square_low_{layer_norm_name}",
            [f"{block_name}precision_change_low_{layer_norm_name}"],
            [x_low],
            mock_kernel_values=mock_kernel_values_square_low,
        )
        precision_change_low = self._add_precision_change(
            f"{block_name}precision_change_low_{layer_norm_name}",
            [f"{block_name}concat_layer_{layer_norm_name}"],
            [square_low],
        )

        square_high = self.add_square_layer(
            f"{block_name}square_high_{layer_norm_name}",
            [f"{block_name}precision_change_high_{layer_norm_name}"],
            [x_high],
            mock_kernel_values=mock_kernel_values_square_high,
        )
        precision_change_high = self._add_precision_change(
            f"{block_name}precision_change_high_{layer_norm_name}",
            [f"{block_name}concat_layer_{layer_norm_name}"],
            [square_high],
        )

        ew_mult_low_high = self.add_ew_mult_layer(
            f"{block_name}ew_mult_low_high_{layer_norm_name}",
            [f"{block_name}precision_change_low_high_{layer_norm_name}"],
            [x_low, x_high],
            mock_kernel_values=mock_kernel_values_low_high_mult,
        )
        precision_change_low_high = self._add_precision_change(
            f"{block_name}precision_change_low_high_{layer_norm_name}",
            [f"{block_name}concat_layer_{layer_norm_name}"],
            [ew_mult_low_high],
        )

        # set the ratio between input scale and output scale for relevant layers
        #################
        HIGH_FACTOR = 0
        LOW_FACTOR = 6
        HIGH_LOW_FACTOR = 4

        square_low.forced_output_scale_scalar_dof = 2**LOW_FACTOR
        square_high.forced_output_scale_scalar_dof = 2**HIGH_FACTOR
        ew_mult_low_high.forced_output_scale_scalar_dof = 2**HIGH_LOW_FACTOR

        precision_change_low.forced_output_scale_scalar_dof = 1
        precision_change_high.forced_output_scale_scalar_dof = 1
        precision_change_low_high.forced_output_scale_scalar_dof = 1
        #####################

        weights = layer_norm.export_weights()
        epsilon = weights["epsilon"]

        # With groups, the reduced variance goes through a spatial mean before the final resize.
        out_conv_var_name = f"{block_name}spatial_mean_var_inv_{layer_norm_name}" if groups > 1 else resize_name
        prev_resize_all, equalization_consumer_square = self._concatenate_and_sum(
            f"{block_name}{layer_norm_name}",
            epsilon,
            [precision_change_low, precision_change_high, precision_change_low_high],
            out_conv_var_name,
            norm_groups=groups,
        )

        if groups > 1:
            resize_spatial_var_name = f"{block_name}resize_spatial_var_{layer_norm_name}"
            bias_mode = "single_scale_decomposition"
            epsilon = np.array(epsilon)
            # In the grouped case the inv_sqrt activation is applied here, after the spatial mean.
            reduce_mean_spatial = self.add_spatial_reduce_mean_layer(
                out_conv_var_name,
                [resize_spatial_var_name],
                [prev_resize_all],
                epsilon,
                "inv_sqrt",
                bias_mode=bias_mode,
                precision_mode=precision_mode,
            )

            resize_spatial_var = self.add_resize_layer(
                resize_spatial_var_name,
                [resize_name],
                [reduce_mean_spatial],
                layer_norm_shape,
                bias_mode=bias_mode,
                precision_mode=precision_mode,
                channels=False,
            )
            prev_resize_all = resize_spatial_var

        resize_all = self.add_resize_layer(
            resize_name,
            [f"{block_name}ew_mult_low_var_{layer_norm_name}", f"{block_name}ew_mult_high_var_{layer_norm_name}"],
            [prev_resize_all],
            layer_norm_shape,
        )

        # Scale the low and the high part separately, then add the two products.
        ew_mult_low_var = self.add_ew_mult_layer(
            f"{block_name}ew_mult_low_var_{layer_norm_name}",
            [f"{block_name}ew_add_out_{layer_norm_name}"],
            [resize_all, x_low],
            mock_kernel_values=mock_kernel_values_mult_low_var,
        )
        ew_mult_high_var = self.add_ew_mult_layer(
            f"{block_name}ew_mult_high_var_{layer_norm_name}",
            [f"{block_name}ew_add_out_{layer_norm_name}"],
            [resize_all, x_high],
            mock_kernel_values=mock_kernel_values_mult_high_var,
        )

        ew_add_out = self.add_ew_add_layer(
            f"{block_name}ew_add_out_{layer_norm_name}",
            [f"{block_name}equalization_consumer_out_{layer_norm_name}"],
            [ew_mult_low_var, ew_mult_high_var],
        )

        # The last layer inherits the successors of the original layer norm.
        _, _, output_layer_names = self._get_output_properties(layer_norm)
        equalization_consumer_out = self.add_normalization_layer(
            f"{block_name}equalization_consumer_out_{layer_norm_name}",
            output_layer_names,
            [ew_add_out],
        )

        self._set_new_output_of_layer(layer_norm, equalization_consumer_out)
        # Register the splitter for the post-processing done in `_run_int_core`.
        if self._equalization:
            self._equalization_info_by_layer[splitter_layer.full_name] = EquivClassNorm(
                equalization_source.full_name,
                equalization_consumer_square.full_name,
                equalization_consumer_out.full_name,
            )
        else:
            self._precision_split_layers.add(splitter_layer.full_name)

        self._remove_layer(layer_norm)

    def decompose_layer_norm_8bit_input(self, layer_norm):
        # still wip
        if self._equalization:
            self._decompose_layer_norm_8bit_input_with_equalization(layer_norm)
        else:
            self._decompose_layer_norm_8bit_input_no_equalization(layer_norm)

    def _decompose_layer_norm_8bit_input_with_equalization(self, layer_norm, factor=64):
        """Replace an 8 bit input layer norm with an equalizable sub-graph (WIP).

        The (optionally mean-subtracted) input goes through an
        ``equalization_source`` normalization layer; its square is reduced into
        the inverse-sqrt term, resized and multiplied with the source. The
        result passes through an ``equalization_consumer_out`` normalization
        layer that takes over the outputs of ``layer_norm``. The source and
        both consumers are registered for equalization and ``layer_norm`` is
        removed.

        Args:
            layer_norm: the layer norm to decompose.
            factor: nudging factor; only used when nudging is enabled, where the
                square is scaled by ``1 / factor`` and the mean layer's factors
                are scaled by ``factor``.
        """
        layer_norm_name = layer_norm.full_name.split("/")[1]
        # splits block name and layer name
        block_name, layer_norm_name = self.get_block_and_layer_names(layer_norm_name)
        layer_norm_shape = layer_norm.to_hn()["input_shapes"][0]

        bias_mode = "double_scale_initialization"
        precision_mode = "a8_w8"

        if not layer_norm.rms_norm:
            # Regular layer norm: subtract the mean before the equalization source.
            pred_layer_norm = self.update_pred_layer(
                layer_norm,
                [f"{block_name}reduce_mean1_{layer_norm_name}", f"{block_name}ew_sub1_{layer_norm_name}"],
            )
            reduce_mean1 = self.add_reduce_mean_layer(
                f"{block_name}reduce_mean1_{layer_norm_name}",
                [f"{block_name}resize1_{layer_norm_name}"],
                [pred_layer_norm],
                bias_mode=bias_mode,
                precision_mode=precision_mode,
            )
            resize1 = self.add_resize_layer(
                f"{block_name}resize1_{layer_norm_name}",
                [f"{block_name}ew_sub1_{layer_norm_name}"],
                [reduce_mean1],
                layer_norm_shape,
                bias_mode=bias_mode,
                precision_mode=precision_mode,
            )
            # NOTE(review): the declared output is square1 while the layer actually
            # feeds equalization_source - confirm whether this name matters.
            ew_sub1 = self.add_ew_sub_layer(
                f"{block_name}ew_sub1_{layer_norm_name}",
                [f"{block_name}square1_{layer_norm_name}"],
                [pred_layer_norm, resize1],
                bias_mode=bias_mode,
                precision_mode=precision_mode,
            )
            equalization_source = self.add_normalization_layer(
                f"{block_name}equalization_source_{layer_norm_name}",
                [f"{block_name}square1_{layer_norm_name}", f"{block_name}ew_mult1_{layer_norm_name}"],
                [ew_sub1],
                bias_mode=bias_mode,
                precision_mode=precision_mode,
            )
            resize_name = f"{block_name}resize2_{layer_norm_name}"

        else:
            # RMS norm: the predecessor feeds the equalization source directly.
            pred_layer_norm = self.update_pred_layer(layer_norm, [f"{block_name}equalization_source_{layer_norm_name}"])
            equalization_source = self.add_normalization_layer(
                f"{block_name}equalization_source_{layer_norm_name}",
                [f"{block_name}square1_{layer_norm_name}", f"{block_name}ew_mult1_{layer_norm_name}"],
                [pred_layer_norm],
                bias_mode=bias_mode,
                precision_mode=precision_mode,
            )
            resize_name = f"{block_name}resize1_{layer_norm_name}"
        square1 = self.add_square_layer(
            f"{block_name}square1_{layer_norm_name}",
            [f"{block_name}conv_var_inv_{layer_norm_name}"],
            [equalization_source],
        )

        weights = layer_norm.export_weights()
        epsilon = weights["epsilon"]

        if self._nudging:
            factors = (1 * factor,)
            # The nudge layer feeds the conv mean layer created below, so its declared
            # output must use that layer's name ("conv_var_inv", not "can_var_inv").
            normalization_nudge = self.add_normalization_layer(
                f"{block_name}normalization_nudge_{layer_norm_name}",
                [f"{block_name}conv_var_inv_{layer_norm_name}"],
                [square1],
                factor=1 / factor,
            )
            input_to_reduce = normalization_nudge
        else:
            factors = (1,)
            input_to_reduce = square1

        can_var = self.add_conv_mean_layer(
            f"{block_name}conv_var_inv_{layer_norm_name}",
            [resize_name],
            [input_to_reduce],
            epsilon,
            "inv_sqrt",
            factors=factors,
        )
        resize_all = self.add_resize_layer(
            resize_name,
            [f"{block_name}ew_mult1_{layer_norm_name}"],
            [can_var],
            layer_norm_shape,
        )
        ew_mult1 = self.add_ew_mult_layer(
            f"{block_name}ew_mult1_{layer_norm_name}",
            [f"{block_name}equalization_consumer_out_{layer_norm_name}"],
            [resize_all, equalization_source],
        )

        # The last layer inherits the successors of the original layer norm.
        _, _, output_layer_names = self._get_output_properties(layer_norm)
        equalization_consumer_out = self.add_normalization_layer(
            f"{block_name}equalization_consumer_out_{layer_norm_name}",
            output_layer_names,
            [ew_mult1],
        )
        self._set_new_output_of_layer(layer_norm, equalization_consumer_out)

        # The square-path consumer is the nudge layer when nudging, else the conv mean layer.
        equalization_consumer_square = normalization_nudge if self._nudging else can_var
        self._equalization_info_by_layer[equalization_source.full_name] = EquivClassNorm(
            equalization_source.full_name,
            equalization_consumer_square.full_name,
            equalization_consumer_out.full_name,
        )
        self._remove_layer(layer_norm)

    def _decompose_layer_norm_8bit_input_no_equalization(self, layer_norm):
        layer_norm_name = layer_norm.full_name.split("/")[1]
        # splits block name and layer name
        block_name, layer_norm_name = self.get_block_and_layer_names(layer_norm_name)
        shape_layer_norm = layer_norm.to_hn()["input_shapes"][0]
        if not layer_norm.rms_norm:
            pred_layer_norm = self.update_pred_layer(
                layer_norm,
                [f"{block_name}reduce_mean1_{layer_norm_name}", f"{block_name}ew_sub1_{layer_norm_name}"],
            )
            reduce_mean1 = self.add_reduce_mean_layer(
                f"{block_name}reduce_mean1_{layer_norm_name}",
                [f"{block_name}resize1_{layer_norm_name}"],
                [pred_layer_norm],
            )
            resize1 = self.add_resize_layer(
                f"{block_name}resize1_{layer_norm_name}",
                [f"{block_name}ew_sub1_{layer_norm_name}"],
                [reduce_mean1],
                shape_layer_norm,
            )
            ew_sub1 = self.add_ew_sub_layer(
                f"{block_name}ew_sub1_{layer_norm_name}",
                [f"{block_name}square1_{layer_norm_name}"],
                [pred_layer_norm, resize1],
            )
            layer_to_square = ew_sub1
            resize_name = f"{block_name}resize2_{layer_norm_name}"
            reduce_mean_name = f"{block_name}reduce_mean2_{layer_norm_name}"
        else:
            pred_layer_norm = self.update_pred_layer(
                layer_norm,
                [f"{block_name}square1_{layer_norm_name}", f"{block_name}ew_mult1_{layer_norm_name}"],
            )
            layer_to_square = pred_layer_norm
            resize_name = f"{block_name}resize1_{layer_norm_name}"
            reduce_mean_name = f"{block_name}reduce_mean1_{layer_norm_name}"

        square1 = self.add_square_layer(f"{block_name}square1_{layer_norm_name}", [reduce_mean_name], [layer_to_square])
        weights = layer_norm.export_weights()
        epsilon = weights["epsilon"]

        reduce_mean2 = self.add_reduce_mean_layer(
            reduce_mean_name,
            [resize_name],
            [square1],
            epsilon=epsilon,
            activation="inv_sqrt",
        )
        resize2 = self.add_resize_layer(
            resize_name,
            [f"{block_name}ew_mult1_{layer_norm_name}"],
            [reduce_mean2],
            shape_layer_norm,
        )

        _, _, output_layer_names = self._get_output_properties(layer_norm)
        ew_mult1 = self.add_ew_mult_layer(
            f"{block_name}ew_mult1_{layer_norm_name}", output_layer_names, [resize2, pred_layer_norm]
        )
        self._set_new_output_of_layer(layer_norm, ew_mult1)
        self._remove_layer(layer_norm)

    # region helper functions
    def _get_output_properties(self, layer_norm):
        output_layer_name_full_list = self._model.flow.successors_sorted(layer_norm.full_name)
        indexes = [
            (
                self._model.flow.get_edge_input_index(layer_norm.full_name, output_layer),
                self._model.flow.get_edge_output_index(layer_norm.full_name, output_layer),
            )
            for output_layer in output_layer_name_full_list
        ]
        output_layer_names = [
            output_layer_name_full.split("/")[1] for output_layer_name_full in output_layer_name_full_list
        ]
        output_layers = [
            self._model.layers[output_layer_name_full] for output_layer_name_full in output_layer_name_full_list
        ]

        return indexes, output_layers, output_layer_names

    def _set_new_output_of_layer(self, layer_norm, last_layer):
        indexes, output_layers, _ = self._get_output_properties(layer_norm)

        for output_layer, (input_index, output_index) in zip(output_layers, indexes):
            self._model.flow.add_edge(
                last_layer.full_name,
                output_layer.full_name,
                input_index=input_index,
                output_index=output_index,
            )

        if layer_norm.full_name in self._model.flow.output_layer_order:
            index = self._model.flow._output_layer_order.index(layer_norm.full_name)
            self._model.flow._output_layer_order[index] = last_layer.full_name

    def _concatenate_and_sum_group_nudged(
        self, layer_norm_name, epsilon, concat_input_layer_list, out_var_name, norm_groups=1
    ):
        # concat- normalization_nudge(group conv layer) - conv_var_inv
        block_name, layer_norm_name = self.get_block_and_layer_names(layer_norm_name)
        concat_layer = self.add_concat_layer(
            f"{block_name}concat_layer_{layer_norm_name}",
            [f"{block_name}normalization_nudge_{layer_norm_name}"],
            concat_input_layer_list,
        )

        shape = concat_layer.to_hn()["output_shapes"][0]
        number_of_channels = shape[-1]
        number_of_channels_per_input = number_of_channels // 3

        if norm_groups > 1:
            num_groups_per_inp = norm_groups
        else:
            # This is a rough heuristic that helps with the performance of llama2_7b (that have 4096 channels per input)
            if number_of_channels_per_input < 4096:
                num_groups_per_inp = 4
            else:
                num_groups_per_inp = 8
        number_of_channels_to_sum = (
            number_of_channels_per_input // num_groups_per_inp
            if number_of_channels_per_input % num_groups_per_inp == 0
            else number_of_channels_per_input
        )  # note that for qwen its 512
        # TODO was designed to fully support qwen which is 512 but maybe should be calculated in a smarter way.

        factor = 1 / number_of_channels_to_sum

        if norm_groups > 1:
            activation = "linear"
            epsilon = 0
            conv_var_name = f"{block_name}channels_mean_var_{layer_norm_name}"
        else:
            activation = "inv_sqrt"
            conv_var_name = f"{block_name}conv_var_inv_{layer_norm_name}"
        ##########
        normalization_nudge = self.add_conv_reduce_sum_layer(
            f"{block_name}normalization_nudge_{layer_norm_name}",
            [conv_var_name],
            [concat_layer],
            0,
            number_of_channels_to_sum=number_of_channels_to_sum,
            activation="linear",
        )

        conv_var_inv = self.add_conv_mean_layer(
            conv_var_name,
            [out_var_name],
            [normalization_nudge],
            epsilon,
            activation,
            factors=(1 * factor, 1 * factor, 2 * factor),
            norm_groups=norm_groups,
        )

        conv_var_inv.conv_op.force_rounded_shift_delta = True  # TODO - we may not need this anynorm
        return conv_var_inv, normalization_nudge

    def _concatenate_and_sum(self, layer_norm_name, epsilon, concat_input_layer_list, out_var_name, norm_groups=4):
        """
        1. self._with_nudging and self._group_nudging:
            concat- normalization_nudge(group_conv) - conv_var_inv
        2.self._with_nudging:
            concat- normalization_nudge(normalization) - conv_var_inv
        3. no nudging:          concat- conv_var_inv
        """
        block_name, layer_norm_name = self.get_block_and_layer_names(layer_norm_name)
        if self._nudging and self._group_nudging:
            return self._concatenate_and_sum_group_nudged(
                f"{block_name}{layer_norm_name}",
                epsilon,
                concat_input_layer_list,
                out_var_name,
                norm_groups=norm_groups,
            )
        elif self._nudging:
            return self._concatenate_and_sum_nudged(
                f"{block_name}{layer_norm_name}", epsilon, concat_input_layer_list, out_var_name
            )

        concat_layer = self.add_concat_layer(
            f"{block_name}concat_layer_{layer_norm_name}",
            [f"{block_name}conv_var_inv_{layer_norm_name}"],
            concat_input_layer_list,
        )
        conv_var_inv = self.add_conv_mean_layer(
            f"{block_name}conv_var_inv_{layer_norm_name}",
            [out_var_name],
            [concat_layer],
            epsilon,
            "inv_sqrt",
            factors=(1, 1, 2),
        )
        return conv_var_inv, conv_var_inv

    # endregion
    def _concatenate_and_sum_nudged(self, layer_norm_name, epsilon, concat_input_layer_list, resize_name):
        """
        concat- normalization_nudge(normalization layer) - conv_var_inv
        """
        block_name, layer_norm_name = self.get_block_and_layer_names(layer_norm_name)
        concat_layer = self.add_concat_layer(
            f"{block_name}concat_layer_{layer_norm_name}",
            [f"{block_name}normalization_nudge_{layer_norm_name}"],
            concat_input_layer_list,
        )
        normalization_nudge = self.add_normalization_layer(
            f"{block_name}normalization_nudge_{layer_norm_name}",
            [f"{block_name}vconv_var_inv_{layer_norm_name}"],
            [concat_layer],
            factor=1,
        )
        conv_var_inv = self.add_conv_mean_layer(
            f"{block_name}conv_var_inv_{layer_norm_name}",
            [resize_name],
            [normalization_nudge],
            epsilon,
            "inv_sqrt",
            factors=(1, 1, 2),
        )
        return conv_var_inv, normalization_nudge

    # region decompose default functions
    def _decompose_mercury(self, layer_norm):
        """Mercury scheme: swap ``layer_norm`` for a ``HailoLayerNormMercury`` layer.

        The replacement is built from the original layer's HN element, receives its exported
        weights, takes its place in the model under the same full name and is configured
        with ``a8_w8`` precision.
        """
        layer_name = layer_norm.full_name
        hn_element, weights = layer_norm.to_hn(), layer_norm.export_weights()

        replacement = HailoLayerNormMercury.from_hn(lname=layer_name, hn_element=hn_element)
        replacement.import_weights(weights)

        self._model.layers[layer_name] = replacement
        self.add_config(replacement, bias_mode="single_scale_decomposition", precision_mode="a8_w8")

    # region decompose default functions
    def _decompose_pluto(self, layer_norm):
        """Pluto scheme: swap ``layer_norm`` for a ``HailoLayerNorm`` layer.

        The replacement is built from the original layer's HN element, receives its exported
        weights, takes its place in the model under the same full name and is configured
        with ``a16_w16`` precision.
        """
        layer_name = layer_norm.full_name
        hn_element, weights = layer_norm.to_hn(), layer_norm.export_weights()

        replacement = HailoLayerNorm.from_hn(lname=layer_name, hn_element=hn_element)
        replacement.import_weights(weights)

        self._model.layers[layer_name] = replacement
        self.add_config(replacement, bias_mode="single_scale_decomposition", precision_mode="a16_w16")

    # endregion
    def _get_nudged_factors(self, factor_candidate):
        factors = []
        version = "version1"
        if version == "version1":
            max_number_of_2 = 4
            max_number_of_3 = 2
            max_number_of_5 = 0
        elif version == "version2":
            max_number_of_2 = 2
            max_number_of_3 = 2
            max_number_of_5 = 1
        else:
            max_number_of_2 = 7
            max_number_of_3 = 0
            max_number_of_5 = 0
        for num2 in range(max_number_of_2 + 1):
            for num3 in range(max_number_of_3 + 1):
                for num5 in range(max_number_of_5 + 1):
                    factors.append(2**num2 * 3**num3 * 5**num5)

        def find_closest_smaller_element(a, b):
            result = []
            for num_a in a:
                closest_smaller = max(filter(lambda x: x <= num_a, b), default=None)
                result.append(closest_smaller)
            return np.array(result)

        result = find_closest_smaller_element(factor_candidate, factors)
        max_factor = 2**max_number_of_2 * 3**max_number_of_3 * 5**max_number_of_5

        def get_lcm(nums):
            # calculates the LCM (Least Common Multiple) of a list of numbers.
            # Use the 'reduce' function to apply the 'lcm' function cumulatively to the elements of 'nums'.
            return reduce(lambda x, y: lcm(x, y), nums)

        def gcd(a, b):
            # 'gcd' - Greatest Common Divisor (GCD) of two numbers.
            # Use the Euclidean algorithm to find the GCD of 'a' and 'b'.
            while b:
                a, b = b, a % b
            return a

        def lcm(a, b):
            # 'lcm' - Least Common Multiple (LCM) of two numbers.
            # Calculate the LCM using the formula: LCM(a, b) = (a * b) / GCD(a, b).
            return a * b // gcd(a, b)

        _ = get_lcm(np.unique(result))

        return result, max_factor

    # region optimization Functions
    def equalize_layer(self, lname):
        """Equalize the layers of the equivalence class registered for ``lname``.

        Flow:
        1. fetch the source / consumer_square / consumer_out layers of the class
        2. compute equalization factors from the source layer and nudge them with
           ``_get_nudged_factors``; the consumers get the inverse (and squared inverse) factors
        3. if the expected accumulator shift of the square consumer exceeds 23, only
           ``_split_layer`` is called (without factors) and the method returns
        4. otherwise the forced kernel scales / output factors are computed with
           ``_calc_needed_factor``, the new kernels are imported into the three layers and
           ``_split_layer`` is called with the source factors
        """
        equiv_class = self._equalization_info_by_layer[lname]
        source_layer = self._model.layers[equiv_class.source]
        consumer_square_layer = self._model.layers[equiv_class.consumer_square]
        consumer_out_layer = self._model.layers[equiv_class.consumer_out]

        equalization_factors = self.get_equalization_factors(source_layer)

        source_kernel, c1_factor = self._get_nudged_factors(equalization_factors)
        consumer_out_kernel = 1 / source_kernel
        consumer_square_kernel = consumer_out_kernel**2

        # default scale values
        source_scale_candidate = 1
        consumer_out_scale_candidate = 1 / c1_factor
        scale_consumer_square = consumer_out_scale_candidate**2

        max_accumulator_shift = self._calc_max_accumulator_shift(
            consumer_square_kernel, scale_consumer_square, consumer_square_layer.conv_op.groups
        )
        if max_accumulator_shift > 23:
            self._split_layer(self._model, lname)
            return

        source_kernel_q_candidate = source_kernel / source_scale_candidate
        consumer_out_kernel_q_candidate = consumer_out_kernel / consumer_out_scale_candidate

        scale_source, source_forced_output_factor = self._calc_needed_factor(
            source_layer,
            source_scale_candidate,
            source_kernel_q_candidate,
            factor_in=None,
            factor_out=source_kernel,
        )
        scale_consumer_out, consumer_forced_output_factor = self._calc_needed_factor(
            consumer_out_layer,
            consumer_out_scale_candidate,
            consumer_out_kernel_q_candidate,
            factor_in=source_kernel,
            factor_out=None,
        )

        forced_output_factor_sqaure = 2 ** (-(max_accumulator_shift - 7))

        kernel_updates = (
            (source_layer, source_kernel, scale_source, source_forced_output_factor),
            (consumer_out_layer, consumer_out_kernel, scale_consumer_out, consumer_forced_output_factor),
            (consumer_square_layer, consumer_square_kernel, scale_consumer_square, forced_output_factor_sqaure),
        )
        for target_layer, kernel_factors, kernel_scale, output_factor in kernel_updates:
            self._import_new_kernel_q(target_layer, kernel_factors, kernel_scale, output_factor)

        self._split_layer(self._model, lname, source_kernel)

        if not DEBUG:
            return

        # everything below only feeds the debug prints
        source_kernel_q = source_kernel / scale_source
        consumer_out_kernel_q = consumer_out_kernel / scale_consumer_out
        consumer_square_kernel_q = consumer_square_kernel / scale_consumer_square
        max_accumulator_shift_floor = np.floor(max_accumulator_shift)

        diff = equalization_factors - source_kernel_q
        unique_sol = np.unique(source_kernel_q)
        unique_sol_out = np.unique(consumer_out_kernel_q)
        unique_sol_square = np.unique(consumer_square_kernel_q)

        print("diff", min(diff), max(diff))

        print("solution", unique_sol, len(unique_sol))
        print("unique_sol_out", unique_sol_out, len(unique_sol_out))
        print("unique_sol_square", unique_sol_square, len(unique_sol_square))

        expected_forced_output_factor = c1_factor**2
        print(
            f"max_accumulator_shift_floor {max_accumulator_shift_floor}  forced_output_factor_sqaure {forced_output_factor_sqaure} expected_forced_output_factor {expected_forced_output_factor} ratio {expected_forced_output_factor/forced_output_factor_sqaure}"
        )

    @staticmethod
    def get_equalization_factors_old(source_layer):
        """
        get equalization factors for the layer
        """
        input_stats = source_layer.get_input_stats()[0]
        epsilon = 1e-10

        stats_min = np.where(input_stats.min == 0, epsilon * -1, input_stats.min)
        stats_max = np.where(input_stats.max == 0, epsilon, input_stats.max)

        min_all = np.min(stats_min)
        max_all = np.max(stats_max)

        factor_max = np.maximum(max_all / stats_max, min_all / stats_max)
        factor_min = np.maximum(min_all / stats_min, max_all / stats_min)

        return np.minimum(factor_max, factor_min)

    @staticmethod
    def get_equalization_factors(source_layer):
        """
        get equalization factors for the layer
        """
        input_stats = source_layer.get_input_stats()[0]
        epsilon = 1e-10

        stats_abs = np.maximum(np.maximum(np.abs(input_stats.min), np.abs(input_stats.max)), epsilon)
        max_all = np.max(stats_abs)

        factor_all = max_all / stats_abs

        return factor_all

    @staticmethod
    def _calc_max_accumulator_shift(equalization_consumer_square_kernel, scale_consumer_square, groups):
        equalization_consumer_square_kernel_q = equalization_consumer_square_kernel / scale_consumer_square
        val = equalization_consumer_square_kernel_q.reshape((-1, (groups // 3)))
        sumation = np.sum(val, axis=0)
        log_2 = np.log2(sumation)
        max_accumulator_shift = np.max(log_2)
        return max_accumulator_shift

    def _calc_needed_factor(self, layer, kernel_scale_forced, kernel_q, factor_in=None, factor_out=None):
        """
        given the layer and the kernel_scale_forced and the equalization change  we want to calculate the expected output_scale and s
        et the new forced output factor (after nudgint)
        1. calc the expected output sacle - and nudge it if needed
        2. if there is needed shift- push it into the kernel scale forced

        """
        # get input_scales and output_scales after "equalization"
        input_scale, _ = layer.conv_op.calc_input_encoding_candidates(0, factor=factor_in)
        output_scale, _ = layer.output_op.calc_output_encoding_candidates(
            0, output_lossy_external=layer.output_lossy_element_external, factor=factor_out
        )

        ###### nudge the output scale if needed
        acc_scale = input_scale[:: layer.kernel.shape[-2]] * kernel_scale_forced
        output_factors_val = np.array(acc_scale / output_scale, dtype=layer.act_op.FLOAT_TYPE_NP)
        nonzero = 1.0
        exponent_factors, mantissas_candidate, exponents = layer.act_op._get_mantissa_exponent_decomposition(
            np.array([nonzero], dtype=layer.act_op.FLOAT_TYPE_NP), output_factors_val
        )
        mantissas = np.floor(mantissas_candidate)
        output_factor = np.squeeze(exponent_factors * mantissas * layer.act_op.final_shift_factor / nonzero)

        # push the shift into the kernel scale forced if needed
        act_op = layer.activation_atomic_op
        assigned_exp = act_op.get_assigned_exponent(-exponents)
        shift_fix_exp = np.max(-assigned_exp)
        shift_fix = np.max([0, np.max(shift_fix_exp)])

        max_shift = np.floor(np.log2((2**15 - 1) / np.max(kernel_q)))

        if shift_fix > max_shift:
            self.logger.info(f"the shift wnated to be {shift_fix} but is {max_shift}")
            shift_fix = max_shift

        forced_output_factor = output_factor[0] / (2.0**shift_fix)
        kernel_scale_forced /= 2**shift_fix

        return kernel_scale_forced, forced_output_factor

    @staticmethod
    def _split_layer(model, lname, factor=None):
        """Create the splits of layer ``lname`` when it is a ``HailoPrecisionSplit``.

        Layers of any other type are left untouched. ``factor`` is forwarded to
        ``create_splits``.
        """
        candidate = model.layers[lname]
        if not isinstance(candidate, HailoPrecisionSplit):
            return
        candidate.create_splits(factor=factor)

    def mask_exp_decompose(self, lname):
        MASK = np.array([2 ** (-6), 2 ** (-4), 2 ** (-2), 2 ** (-1), 2**0])

        exp_decompose = self._mask_exp_decompose_info_by_layer[lname].exp_decompose
        shift = self._mask_exp_decompose_info_by_layer[lname].shift

        # get input bits (signed so we need to substract 1)
        bits = self._model.layers[exp_decompose].get_input_lossy_elements()[0].bits - 1

        act_stats = self._model.layers[exp_decompose].act_op.get_input_stats(0)
        max_value = np.max(np.maximum(((2**bits) / (2**bits - 1)) * np.abs(act_stats.max), np.abs(act_stats.min)))
        min_value = -max_value

        # force ranges so that input/output are symetric, and exp_decompose is power of 2
        pred = self._model.flow.predecessors_sorted(lname)[0]
        self._force_range(pred, (((2**bits - 0.5) / (2**bits)) * min_value, ((2**bits - 1) / (2**bits)) * max_value))
        self._force_range(exp_decompose, (0, ((2 ** (bits + 1) - 1) / (2 ** (bits - 1))) * max_value))
        self._force_range(shift, (-1, ((2**bits - 1) / (2**bits))))

        mask = np.concatenate([MASK[::-1] * min_value, MASK * max_value])

        self._model.layers[exp_decompose].act_op.act_native_params["mask"] = mask
        self._model.layers[shift].act_op.act_native_params["mask"] = mask

    @classmethod
    def _import_new_kernel_q(cls, layer, equalization_kernel, kernel_scale_forced=1, forced_output_factor=None):
        # this function creates the forced kernel we want
        kernel = layer.get_kernel().numpy()
        if layer.groups > 1:
            equalization_kernel_factor = np.tile(equalization_kernel.reshape((-1, kernel.shape[2])).transpose(), 3)
        else:
            equalization_kernel_factor = np.expand_dims(equalization_kernel, 1)
        kernel_new = kernel * equalization_kernel_factor
        layer.conv_op.import_weights(kernel_new)

        kernel_q = kernel_new / kernel_scale_forced

        layer.conv_op.kernel_q_forced = kernel_q
        layer.conv_op.kernel_scale_forced = kernel_scale_forced
        layer.conv_op.kernel_scale_forced_to_save = True

        if forced_output_factor is not None:
            layer.forced_output_factor = forced_output_factor

    # endregion

    # region add layers
    def update_pred_layer(self, layer, output_names):
        pred_layer = self._model.layers[
            self._model.flow.predecessors_sorted(layer.full_name)[0]
        ]  # get conv layer(input of norm)
        model_name = layer.full_name.split("/")[0]
        # create new hn to
        new_hn = pred_layer.to_hn()
        new_hn["output"] = [f"{model_name}/{output_name}" for output_name in output_names]
        new_hn["output_shapes"] = [new_hn["output_shapes"][0] for _ in output_names]
        new_hn["quantization_params"] = {}

        pred_layer._hn_element = new_hn
        return pred_layer

    def add_normalization_layer(
        self,
        new_name,
        output_names,
        layers_inp,
        activation="linear",
        factor=1,
        bias_mode="single_scale_decomposition",
        precision_mode="a16_w16",
    ):
        layer_inp = layers_inp[0]
        model_name = layer_inp.full_name.split("/")[0]
        inp_hn = layer_inp.to_hn()
        shape = inp_hn["output_shapes"][0]
        input_channel = shape[-1]
        kernel_shape = [1, 1, input_channel, 1]
        hn_element = {
            "type": "normalization",
            "input": [layer_inp.full_name],
            "output": [f"{model_name}/{output_name}" for output_name in output_names],
            "input_shapes": [shape],
            "output_shapes": [shape],
            "compilation_params": {},
            "quantization_params": {},
            "params": {
                "kernel_shape": kernel_shape,
                "strides": [1, 1, 1, 1],
                "dilations": [1, 1, 1, 1],
                "padding": "VALID",
                "groups": 1,
                "layer_disparity": 1,
                "input_disparity": 1,
                "batch_norm": False,
                "elementwise_add": False,
                "activation": activation,
            },
        }
        kernel = np.ones(input_channel, dtype=np.float32) * factor
        kernel = kernel.reshape(kernel_shape)
        bias = np.zeros([shape[-1]], dtype=np.float32)
        weights = {"kernel": kernel, "bias": bias}
        new_layer = self._add_generic_layer(hn_element, model_name, new_name, layers_inp, weights)
        new_layer.trainable = False
        new_layer.bias_add_op.is_correctable = False
        self.add_config(new_layer, bias_mode=bias_mode, precision_mode=precision_mode)

        return new_layer

    def add_shortcut_layer(self, new_name, output_names, layers_inp, output_index=0):
        """
        add shortcut layer between layer norm and conv ( or whatever it is)
        """
        layer_inp = layers_inp[0]

        model_name = layer_inp.full_name.split("/")[0]
        inp_hn = layer_inp.to_hn()
        shape = inp_hn["output_shapes"][0]
        hn_element = {
            "type": "shortcut",
            "input": [layer_inp.full_name],
            "output": [f"{model_name}/{output_name}" for output_name in output_names],
            "input_shapes": [shape],
            "output_shapes": [shape],
        }
        # add layer to model
        weights = dict()
        new_layer = self._add_generic_layer(
            hn_element,
            model_name,
            new_name,
            layers_inp,
            weights,
            output_index=output_index,
        )

        return new_layer

    def _add_precision_change(self, new_name, output_names, layers_inp):
        layer_inp = layers_inp[0]

        model_name = layer_inp.full_name.split("/")[0]
        inp_hn = layer_inp.to_hn()
        shape = inp_hn["output_shapes"][0]

        hn_element = {
            "type": "activation",
            "input": [layer_inp.full_name],
            "output": [f"{model_name}/{output_name}" for output_name in output_names],
            "input_shapes": [shape],
            "output_shapes": [shape],
            "params": {"activation": "linear"},
        }
        new_layer = self._add_generic_layer(hn_element, model_name, new_name, layers_inp)
        self.add_config(new_layer, bias_mode="single_scale_decomposition", precision_mode="a8_w8_a16")
        return new_layer

    def _add_precision_split_layer(self, new_name, output_names, layers_inp):
        """
        add add_precision_split_layer layer between layer norm and conv ( or whatever it is)
        """
        layer_inp = layers_inp[0]

        model_name = layer_inp.full_name.split("/")[0]
        inp_hn = layer_inp.to_hn()
        shape = inp_hn["output_shapes"][0]
        hn_element = {
            "type": "precision_splitter",
            "input": [layer_inp.full_name],
            "output": [f"{model_name}/{output_name}" for output_name in output_names],
            "input_shapes": [shape],
            "output_shapes": [shape, shape],
        }
        new_layer = self._add_generic_layer(hn_element, model_name, new_name, layers_inp)
        self.add_config(new_layer, bias_mode="single_scale_decomposition", precision_mode="a16_w16")
        return new_layer

    def add_spatial_reduce_mean_layer(
        self,
        new_name,
        output_names,
        layers_inp,
        epsilon=None,
        activation="linear",
        bias_mode="double_scale_initialization",
        precision_mode="a8_w8",
    ):
        layer_inp = layers_inp[0]
        inp_hn = layer_inp.to_hn()
        shape = inp_hn["output_shapes"][0][1:3]

        add_two_spatial = np.prod(shape) > 256**2  # to avoid over flow
        if add_two_spatial:
            block_name, layer_name = self.get_block_and_layer_names(new_name)
            second_name = f"{block_name}second_{layer_name}"
            first_reduce_mean = self._add_one_spatial_reduce_mean_layer(
                new_name,
                [second_name],
                layers_inp,
                [2],
                epsilon=None,
                activation="linear",
                bias_mode=bias_mode,
                precision_mode=precision_mode,
            )
            second_reduce_mean = self._add_one_spatial_reduce_mean_layer(
                second_name,
                output_names,
                [first_reduce_mean],
                [1],
                epsilon=epsilon,
                activation=activation,
                bias_mode=bias_mode,
                precision_mode=precision_mode,
            )
        else:
            second_reduce_mean = self._add_one_spatial_reduce_mean_layer(
                new_name,
                output_names,
                layers_inp,
                [1, 2],
                epsilon=epsilon,
                activation=activation,
                bias_mode=bias_mode,
                precision_mode=precision_mode,
            )
        return second_reduce_mean

    def _add_one_spatial_reduce_mean_layer(
        self,
        new_name,
        output_names,
        layers_inp,
        reduce_axes,
        epsilon=None,
        activation="linear",
        bias_mode="double_scale_initialization",
        precision_mode="a8_w8",
    ):
        layer_inp = layers_inp[0]
        model_name = layer_inp.full_name.split("/")[0]
        inp_hn = layer_inp.to_hn()
        shape = inp_hn["output_shapes"][0]
        output_shape = shape[-1]
        if tuple(reduce_axes) not in {(1, 2), (1,), (2,)}:
            raise ValueError(" reduce_axes is not [1,2], [1], [2]")

        out_shape = shape.copy()
        out_shape[-1] = output_shape
        for axis in reduce_axes:
            out_shape[axis] = 1
        hn_element = {
            "type": "reduce_mean",
            "input": [layer_inp.full_name],
            "output": [f"{model_name}/{output_name}" for output_name in output_names],
            "input_shapes": [shape],
            "output_shapes": [out_shape for _ in output_names],
            "compilation_params": {},
            "quantization_params": {},
            "params": {
                "reduce_axes": reduce_axes,
                "activation": activation,
            },
        }
        weights = None
        if epsilon is not None:
            bias = np.ones([1]) * epsilon
            weights = {"bias": bias}

        new_layer = self._add_generic_layer(hn_element, model_name, new_name, layers_inp, weights=weights)
        self.add_config(new_layer, bias_mode=bias_mode, precision_mode=precision_mode)
        return new_layer

    def add_reduce_mean_layer(
        self,
        new_name,
        output_names,
        layers_inp,
        epsilon=None,
        activation="linear",
        groups=1,
        bias_mode="double_scale_initialization",
        precision_mode="a8_w8",
    ):
        layer_inp = layers_inp[0]
        model_name = layer_inp.full_name.split("/")[0]
        inp_hn = layer_inp.to_hn()
        shape = inp_hn["output_shapes"][0]
        out_shape = shape[:-1] + [groups]
        hn_element = {
            "type": "reduce_mean",
            "input": [layer_inp.full_name],
            "output": [f"{model_name}/{output_name}" for output_name in output_names],
            "input_shapes": [shape],
            "output_shapes": [out_shape for _ in output_names],
            "compilation_params": {},
            "quantization_params": {},
            "params": {
                "reduce_axes": [3],
                "groups": groups,
                "activation": activation,
            },
        }
        weights = None
        if epsilon is not None:
            bias = np.ones([1]) * epsilon
            weights = {"bias": bias}

        new_layer = self._add_generic_layer(hn_element, model_name, new_name, layers_inp, weights=weights)
        self.add_config(new_layer, bias_mode=bias_mode, precision_mode=precision_mode)
        return new_layer

    def add_conv_reduce_sum_layer(
        self,
        new_name,
        output_names,
        layers_inp,
        epsilon,
        number_of_channels_to_sum=256,
        activation="linear",
        bias_mode="single_scale_decomposition",
        precision_mode="a16_w16",
    ):
        """Add a grouped 1x1 conv with an all-ones kernel that sums groups of channels.

        The number of output channels (and of conv groups) is the input channel count
        divided by ``number_of_channels_to_sum``; every output channel gets a bias equal to
        ``epsilon``. The created layer is marked non-trainable, its bias is marked
        non-correctable and it is configured with the given ``bias_mode`` / ``precision_mode``.

        Returns:
            the created layer.
        """
        source = layers_inp[0]
        model_name = source.full_name.split("/")[0]
        shape = source.to_hn()["output_shapes"][0]

        input_channel = shape[-1]
        # e.g. 6144 / 128 = 48
        number_of_channels_out = input_channel // number_of_channels_to_sum
        output_shape = shape[:-1] + [number_of_channels_out]

        conv_params = {
            "kernel_shape": [1, 1, input_channel, number_of_channels_out],
            "strides": [1, 1, 1, 1],
            "dilations": [1, 1, 1, 1],
            "padding": "VALID",
            "groups": number_of_channels_out,
            "layer_disparity": 1,
            "input_disparity": 1,
            "batch_norm": False,
            "elementwise_add": False,
            "activation": activation,
        }
        hn_element = {
            "type": "conv",
            "input": [source.full_name],
            "output": [f"{model_name}/{output_name}" for output_name in output_names],
            "input_shapes": [shape],
            "output_shapes": [output_shape for _ in output_names],
            "compilation_params": {},
            "quantization_params": {},
            "params": conv_params,
        }

        # all-ones kernel, e.g. of shape (1, 1, 128, 48): each group simply sums its input channels
        kernel = np.ones(shape=(1, 1, number_of_channels_to_sum, number_of_channels_out))

        if DEBUG:
            kernel_shape = [1, 1, number_of_channels_to_sum, number_of_channels_out]
            print("kernel shape", kernel.shape, kernel_shape)
            print("input_shapes", shape)

            print("output shape", output_shape)

            print("number of groups", number_of_channels_out)
            print("number_of_channels_to_sum", number_of_channels_to_sum)

        weights = {"kernel": kernel, "bias": np.ones(shape=number_of_channels_out) * epsilon}

        new_layer = self._add_generic_layer(hn_element, model_name, new_name, layers_inp, weights)
        new_layer.trainable = False
        new_layer.bias_add_op.is_correctable = False

        self.add_config(new_layer, bias_mode=bias_mode, precision_mode=precision_mode)
        return new_layer

    def add_conv_mean_layer(
        self,
        new_name,
        output_names,
        layers_inp,
        epsilon,
        activation="linear",
        factors=(1, 1, 2),
        bias_mode="single_scale_decomposition",
        precision_mode="a16_w16",
        norm_groups=1,
    ):
        layer_inp = layers_inp[0]
        model_name = layer_inp.full_name.split("/")[0]
        inp_hn = layer_inp.to_hn()
        shape = inp_hn["output_shapes"][0]

        input_channel = shape[-1]
        kernel_shape = [1, 1, input_channel, norm_groups]
        hn_element = {
            "type": "conv",
            "input": [layer_inp.full_name],
            "output": [f"{model_name}/{output_name}" for output_name in output_names],
            "input_shapes": [shape],
            "output_shapes": [shape[:-1] + [norm_groups] for _ in output_names],
            "compilation_params": {},
            "quantization_params": {},
            "params": {
                "kernel_shape": kernel_shape,
                "strides": [1, 1, 1, 1],
                "dilations": [1, 1, 1, 1],
                "padding": "VALID",
                "groups": 1,
                "layer_disparity": 1,
                "input_disparity": 1,
                "batch_norm": False,
                "elementwise_add": False,
                "activation": activation,
            },
        }
        kernels = []
        num_for_fact = input_channel // len(factors) // norm_groups
        for fact in factors:
            fact_to = fact / num_for_fact
            if norm_groups > 1:
                kernel = np.eye(norm_groups) * fact_to
            else:
                kernel = np.repeat(fact_to, num_for_fact)
            kernels.append(kernel)
        kernel = np.concatenate(kernels)
        kernel = kernel.reshape(kernel_shape)
        bias = np.ones([norm_groups]) * epsilon
        weights = {"kernel": kernel, "bias": bias}

        new_layer = self._add_generic_layer(hn_element, model_name, new_name, layers_inp, weights)
        new_layer.trainable = False
        new_layer.bias_add_op.is_correctable = False

        self.add_config(new_layer, bias_mode=bias_mode, precision_mode=precision_mode)
        return new_layer

    def add_resize_layer(
        self,
        new_name,
        output_names,
        layers_inp,
        layer_norm_shape,
        bias_mode="single_scale_decomposition",
        precision_mode="a8_w8",
        channels=True,
    ):
        """Add a nearest-neighbor ``resize`` layer that brings its input to the layer-norm shape.

        Args:
            new_name: Name of the new layer, without the model prefix.
            output_names: Names (without the model prefix) listed as outputs of the new layer.
            layers_inp: Predecessor layers; only the first one is used as the input.
            layer_norm_shape: Target shape, indexed as ``[batch, height, width, features]``.
            bias_mode: Bias mode stored in the layer precision config.
            precision_mode: Precision mode stored in the layer precision config.
            channels: When True only the features axis is resized; otherwise only
                the height and width axes are resized.

        Returns:
            The newly created layer (already added to the model by ``_add_generic_layer``).
        """
        layer_inp = layers_inp[0]
        model_name = layer_inp.full_name.split("/")[0]
        inp_shape = layer_inp.to_hn()["output_shapes"][0]

        # Start from an identity resize and override only the axes being resized.
        resize_h_ratio = resize_w_ratio = resize_f_ratio = 1.0
        out_h, out_w, out_f = inp_shape[1], inp_shape[2], inp_shape[3]
        if channels:
            resize_f_ratio = float(layer_norm_shape[3] / inp_shape[3])
            out_f = int(layer_norm_shape[3])
        else:
            resize_h_ratio = float(layer_norm_shape[1] / inp_shape[1])
            resize_w_ratio = float(layer_norm_shape[2] / inp_shape[2])
            out_h = int(layer_norm_shape[1])
            out_w = int(layer_norm_shape[2])

        # The resized dims are taken directly from the target shape. Computing them as
        # int(ratio) * dim truncates the ratio first and yields a shape that disagrees
        # with the declared ratio whenever the ratio is not an integer (e.g. 1.5 -> 1).
        output_shape = [inp_shape[0], out_h, out_w, out_f]
        hn_element = {
            "type": "resize",
            "input": [layer_inp.full_name],
            "output": [f"{model_name}/{output_name}" for output_name in output_names],
            "input_shapes": [inp_shape],
            "output_shapes": [output_shape for _ in output_names],
            "compilation_params": {
                "hw_layer_type_list": ["lcu"],
            },
            "quantization_params": {},
            "params": {
                "resize_h_ratio_list": [resize_h_ratio],
                "resize_w_ratio_list": [resize_w_ratio],
                "resize_f_ratio_list": [resize_f_ratio],
                "method": "nearest_neighbor",
                "resize_bilinear_pixels_mode": "disabled",
            },
        }

        new_layer = self._add_generic_layer(hn_element, model_name, new_name, layers_inp)
        self.add_config(new_layer, bias_mode=bias_mode, precision_mode=precision_mode)
        return new_layer

    def add_ew_sub_layer(
        self,
        new_name,
        output_names,
        layers_inp,
        bias_mode="double_scale_initialization",
        precision_mode="a8_w8",
    ):
        """Create an ``ew_sub`` layer fed by the first two entries of ``layers_inp``.

        The output shape of the first input is used for both input shapes and for
        every output shape. The layer is added to the model and given a precision
        config built from ``bias_mode`` and ``precision_mode``.

        Returns:
            The newly created layer.
        """
        first_inp = layers_inp[0]
        scope = first_inp.full_name.split("/")[0]
        in_shape = first_inp.to_hn()["output_shapes"][0]

        hn_element = dict(
            type="ew_sub",
            input=[first_inp.full_name, layers_inp[1].full_name],
            output=[f"{scope}/{name}" for name in output_names],
            input_shapes=[in_shape, in_shape],
            output_shapes=[in_shape] * len(output_names),
            compilation_params={},
            quantization_params={},
            params={"activation": "linear"},
        )

        # Register the layer in the model, then attach its precision config.
        sub_layer = self._add_generic_layer(hn_element, scope, new_name, layers_inp)
        self.add_config(sub_layer, bias_mode=bias_mode, precision_mode=precision_mode)
        return sub_layer

    def add_ew_add_layer(self, new_name, output_names, layers_inp):
        """Create an ``ew_add`` layer fed by the first two entries of ``layers_inp``.

        The output shape of the first input is used for both input shapes and for
        every output shape.

        Returns:
            The newly created layer.
        """
        first_inp = layers_inp[0]
        scope = first_inp.full_name.split("/")[0]
        in_shape = first_inp.to_hn()["output_shapes"][0]

        hn_element = dict(
            type="ew_add",
            input=[first_inp.full_name, layers_inp[1].full_name],
            output=[f"{scope}/{name}" for name in output_names],
            input_shapes=[in_shape, in_shape],
            output_shapes=[in_shape] * len(output_names),
            compilation_params={},
            quantization_params={},
            params={"activation": "linear"},
        )

        # NOTE(review): unlike the sibling builders, no precision config is attached
        # here (the add_config call was previously disabled) - confirm this is intended.
        return self._add_generic_layer(hn_element, scope, new_name, layers_inp)

    def add_ew_mult_layer(
        self,
        new_name,
        output_names,
        layers_inp,
        mock_kernel_values=None,
        bias_mode="double_scale_initialization",
        precision_mode="a8_w8",
    ):
        """Create an ``ew_mult`` layer fed by the first two entries of ``layers_inp``.

        The output shape of the first input is used for both input shapes and for
        every output shape.

        Args:
            new_name: Name of the new layer, without the model prefix.
            output_names: Names (without the model prefix) listed as outputs of the new layer.
            layers_inp: Predecessor layers; the first two are wired as inputs.
            mock_kernel_values: Value stored on the layer's ``mock_kernel_values``
                attribute. Defaults to a fresh ``[2, 2]`` list per call.
            bias_mode: Bias mode stored in the layer precision config.
            precision_mode: Precision mode stored in the layer precision config.

        Returns:
            The newly created layer.
        """
        if mock_kernel_values is None:
            # Build the default per call: a list literal in the signature is a single
            # object that every layer created with the default would share.
            mock_kernel_values = [2, 2]

        model_name = layers_inp[0].full_name.split("/")[0]
        inp_hn = layers_inp[0].to_hn()
        shape = inp_hn["output_shapes"][0]

        hn_element = {
            "type": "ew_mult",
            "input": [layers_inp[0].full_name, layers_inp[1].full_name],
            "output": [f"{model_name}/{output_name}" for output_name in output_names],
            "input_shapes": [shape, shape],
            "output_shapes": [shape for _ in output_names],
            "compilation_params": {},
            "quantization_params": {},
            "params": {
                "activation": "linear",
            },
        }

        new_layer = self._add_generic_layer(hn_element, model_name, new_name, layers_inp)
        new_layer.mock_kernel_values = mock_kernel_values
        self.add_config(new_layer, bias_mode=bias_mode, precision_mode=precision_mode)
        return new_layer

    def add_concat_layer(self, new_name, output_names, layers_inp):
        """Create a features-axis ``concat`` layer over all layers in ``layers_inp``.

        Every input is described with the output shape of the first input, and the
        output feature count is that shape's last dimension multiplied by the number
        of inputs, so all inputs are assumed to share the first input's shape.

        Args:
            new_name: Name of the new layer, without the model prefix.
            output_names: Output names (without the model prefix); only the first is used.
            layers_inp: Predecessor layers to concatenate, in order.

        Returns:
            The newly created layer, with ``atomic_op.vector_zp`` set to True.
        """
        layer_inp = layers_inp[0]

        model_name = layer_inp.full_name.split("/")[0]
        inp_hn = layer_inp.to_hn()
        shape = inp_hn["output_shapes"][0]

        num_to = len(layers_inp)
        hn_element = {
            "type": "concat",
            # List every predecessor so the input names always match the num_to
            # entries in input_shapes, instead of hard-coding exactly three inputs.
            "input": [layer.full_name for layer in layers_inp],
            "output": [f"{model_name}/{output_names[0]}"],
            "input_shapes": [shape] * num_to,
            "output_shapes": [shape[:-1] + [shape[-1] * num_to]],
            "original_names": ["/Concat"],
            "compilation_params": {},
            "quantization_params": {},
            "params": {
                "concat_axis": "features",
            },
        }

        new_layer = self._add_generic_layer(hn_element, model_name, new_name, layers_inp)
        new_layer.atomic_op.vector_zp = True
        return new_layer

    def add_square_layer(
        self,
        new_name,
        output_names,
        layers_inp,
        mock_kernel_values=None,
        bias_mode="double_scale_initialization",
        precision_mode="a8_w8",
    ):
        """Create a ``feature_multiplier`` layer of type ``square`` on the first input.

        Args:
            new_name: Name of the new layer, without the model prefix.
            output_names: Output names (without the model prefix); only the first is used.
            layers_inp: Predecessor layers; the first one supplies the input name and shape.
            mock_kernel_values: Value stored on the layer's ``mock_kernel_values``
                attribute. Defaults to a fresh ``[2, 2]`` list per call.
            bias_mode: Bias mode stored in the layer precision config.
            precision_mode: Precision mode stored in the layer precision config.

        Returns:
            The newly created layer.
        """
        if mock_kernel_values is None:
            # Build the default per call: a list literal in the signature is a single
            # object that every layer created with the default would share.
            mock_kernel_values = [2, 2]

        layer_inp = layers_inp[0]
        model_name = layer_inp.full_name.split("/")[0]
        inp_hn = layer_inp.to_hn()
        shape = inp_hn["output_shapes"][0]

        hn_element = {
            "type": "feature_multiplier",
            "input": [layer_inp.full_name],
            "output": [f"{model_name}/{output_names[0]}"],
            "input_shapes": [shape],
            "output_shapes": [shape],
            # NOTE(review): this name is hard-coded regardless of the input layer;
            # presumably informational only - confirm.
            "original_names": ["/blocks/blocks.0/norm2/LayerNormalization"],
            "compilation_params": {},
            "quantization_params": {},
            "params": {
                "activation": "linear",
                "feature_multiplier_type": "square",
            },
        }
        # add layer to model
        new_layer = self._add_generic_layer(hn_element, model_name, new_name, layers_inp)
        new_layer.mock_kernel_values = mock_kernel_values
        self.add_config(new_layer, bias_mode=bias_mode, precision_mode=precision_mode)

        return new_layer

    def add_standalone_activation(
        self,
        new_name,
        output_names,
        layers_inp,
        activation="linear",
        precision_mode="a8_w8",
    ):
        """Create a standalone ``activation`` layer on the first entry of ``layers_inp``.

        Input and output shapes are all the output shape of that first input. The
        precision config is always attached with bias mode
        ``double_scale_initialization``.

        Returns:
            The newly created layer.
        """
        source = layers_inp[0]
        scope = source.full_name.split("/")[0]
        in_shape = source.to_hn()["output_shapes"][0]

        hn_element = dict(
            type="activation",
            input=[source.full_name],
            output=[f"{scope}/{name}" for name in output_names],
            input_shapes=[in_shape],
            output_shapes=[in_shape] * len(output_names),
            params=dict(activation=activation),
        )

        act_layer = self._add_generic_layer(hn_element, scope, new_name, layers_inp)
        self.add_config(act_layer, bias_mode="double_scale_initialization", precision_mode=precision_mode)
        return act_layer

    def add_reduce_max(self, new_name, output_names, layers_inp, reduce_axes, groups=1, precision_mode="a8_w8"):
        """Create a ``reduce_max`` layer on the first entry of ``layers_inp``.

        ``reduce_axes`` must be one of ``[1, 2]``, ``[1]``, ``[2]`` or ``[3]``, and
        ``groups`` may differ from 1 only when reducing axis 3. In the output shape
        every reduced axis is set to ``groups``.

        Raises:
            ValueError: If ``reduce_axes`` is not supported, or ``groups`` is not 1
                for a reduction over axes 1 and/or 2.

        Returns:
            The newly created layer.
        """
        source = layers_inp[0]
        scope = source.full_name.split("/")[0]
        in_shape = source.to_hn()["output_shapes"][0]

        axes_key = tuple(reduce_axes)
        groupless_axes = {(1, 2), (1,), (2,)}
        if axes_key not in groupless_axes | {(3,)}:
            raise ValueError(" reduce_axes is not [1,2], [1], [2], [3]")
        if axes_key in groupless_axes and groups != 1:
            raise ValueError(" reduce_axes is [1,2], [1], [2] and groups is not 1")

        # Work on a copy so the shape reported by the input layer is left untouched.
        reduced_shape = in_shape.copy()
        for reduced_axis in reduce_axes:
            reduced_shape[reduced_axis] = groups

        hn_element = dict(
            type="reduce_max",
            input=[source.full_name],
            output=[f"{scope}/{name}" for name in output_names],
            input_shapes=[in_shape],
            output_shapes=[reduced_shape] * len(output_names),
            params=dict(groups=groups, reduce_axes=reduce_axes),
        )
        max_layer = self._add_generic_layer(hn_element, scope, new_name, layers_inp)
        self.add_config(max_layer, precision_mode=precision_mode)
        return max_layer

    def add_shift_layer(self, new_name, output_names, layers_inp, precision_mode="a8_w8"):
        """Create an ``ew_add`` layer with the ``shift`` activation over two inputs.

        The output shape is the element-wise maximum of the two input shapes. For
        each input, ``input_repeats`` holds the integer ratio between the output
        and input size of every non-batch dimension.

        Returns:
            The newly created layer.
        """
        scope = layers_inp[0].full_name.split("/")[0]
        shape_a = layers_inp[0].to_hn()["output_shapes"][0]
        shape_b = layers_inp[1].to_hn()["output_shapes"][0]
        out_shape = np.maximum(shape_a, shape_b).tolist()

        # Index 0 (batch) is skipped when computing the per-dimension repeats.
        repeats = [
            [out_dim // in_dim for in_dim, out_dim in zip(in_shape[1:], out_shape[1:])]
            for in_shape in (shape_a, shape_b)
        ]

        hn_element = dict(
            type="ew_add",
            input=[pred.full_name for pred in layers_inp],
            output=[f"{scope}/{name}" for name in output_names],
            input_shapes=[shape_a, shape_b],
            output_shapes=[out_shape] * len(output_names),
            params=dict(activation="shift", input_repeats=repeats),
        )
        shift_layer = self._add_generic_layer(hn_element, scope, new_name, layers_inp)
        self.add_config(shift_layer, bias_mode="double_scale_initialization", precision_mode=precision_mode)
        return shift_layer

    # endregion

    # region add generic function

    def _add_generic_layer(self, hn_element, model_name, new_name, layers_inp, weights=None, output_index=0):
        """Build a layer from an HN element, wire it into the model and register default configs.

        The layer is generated for ``self.optimization_target``, receives ``weights``
        (an empty dict when None), and is connected to ``layers_inp`` through
        ``_add_preds``. Equalization, adaround, bias correction and zero-static-channels
        are set to the ``disabled`` policy for it, and its negative-exponent rank is 0.

        Returns:
            The newly created layer.
        """
        qualified_name = f"{model_name}/{new_name}"
        generated = gen_acceleras_layers_from_hn(qualified_name, hn_element, self.optimization_target)
        new_layer = generated[qualified_name]

        new_layer.import_weights(dict() if weights is None else weights)
        self._add_preds(new_layer, layers_inp, output_index=output_index)

        # Per-layer algorithm configs for the freshly inserted layer.
        model_cfg = self._model_config
        model_cfg.equalization.layers[new_layer.full_name] = LayerEqualizationConfig(policy="disabled")
        model_cfg.adaround.layers[new_layer.full_name] = LayerAdaRoundConfig(policy="disabled")
        model_cfg.negative_exponent.layers[new_layer.full_name] = LayerNegExponentConfig(rank=0)
        model_cfg.bias_correction.layers[new_layer.full_name] = LayerBiasCorrectionConfig(policy="disabled")
        model_cfg.zero_static_channels.layers[new_layer.full_name] = LayerZeroStaticChannelsConfig(
            policy="disabled",
        )
        return new_layer

    def _add_preds(self, new_layer, predecessors, output_index=0):
        self._model.layers[new_layer.full_name] = new_layer
        node = new_layer.full_name
        self._model.flow.add_node(node)
        for i, predecessor in enumerate(predecessors):
            self._model.flow.add_edge(predecessor.full_name, node, input_index=i, output_index=output_index)

    def _remove_layer(self, layer):
        self._model.layers.pop(layer.full_name, None)
        self._model.flow.remove_node(layer.full_name)
        self._model_config.remove_layer_from_all_configs(layer.full_name)

    def add_config(self, new_layer, bias_mode="single_scale_decomposition", precision_mode="a16_w16"):
        """Attach a single-quantization-group precision config to ``new_layer``.

        The config is stored in the model's precision config under the layer's full
        name and imported into the layer for ``self.optimization_target``.
        """
        target_name = new_layer.full_name
        precision_cfg = LayerPrecisionConfig(
            precision_mode=precision_mode,
            bias_mode=bias_mode,
            quantization_groups=1,
        )
        self._model_config.precision_config.layers[target_name] = precision_cfg
        new_layer.import_precision_config(precision_cfg, self.optimization_target)

    def _force_range(self, lname, range):
        """Force the output range of layer ``lname`` unless its meta already records one.

        The layer's translation config is replaced by its finalized entry (or the
        default translation config when there is none). ``force_range_out`` is set to
        ``range`` only when the entry's meta has no ``force_range_out`` key yet; in
        that case a placeholder ``CommandMeta`` is recorded under that key.
        """
        translation_cfg = self._model_config.translation_config
        finalized = self.finalize_layer_cfg(translation_cfg.layers)
        translation_cfg.layers[lname] = finalized.get(lname, LayerTranslationConfig.get_default())

        layer_cfg = translation_cfg.layers[lname]
        meta = layer_cfg.meta
        if meta is None:
            meta = {}
        if "force_range_out" in meta:
            # An entry is already recorded for this layer - leave it untouched.
            return

        meta["force_range_out"] = CommandMeta(line=-1, command="", is_glob=False)
        layer_cfg.force_range_out = range
        layer_cfg.meta = meta

    # endregion

    # region - debug section
    def equalize_layer_old(self, lname):
        """Legacy equalization of the layers recorded for ``lname`` (kept for reference).

        Looks up the equivalence class stored for ``lname`` (source, consumer_square
        and consumer_out layers), derives a scaling factor from the source layer's
        input min/max statistics (capped at ``2**6``), rescales the kernels of the
        three layers with it, and finally calls ``self._split_layer`` with the source
        factor. When ``self._nudging`` is set, the consumer factor is replaced by the
        one obtained from the consumer_square layer's nudged kernel and the source
        factor is recomputed from it.

        Args:
            lname: Key into ``self._equalization_info_by_layer``.
        """
        # This is the equalize-layer function that was used in the past.
        # We may want to use it again in the future.

        def _import_new_kernel(layer, factor_kernel):
            # Given a layer and a new factor, multiply the kernel by the factor
            # (the factor gets an extra axis at position 1 before multiplying).
            kernel = layer.get_kernel().numpy()
            new_kernel = kernel * np.expand_dims(factor_kernel, 1)
            layer.conv_op.import_weights(new_kernel)

        def _get_nudges_factor(layer, factor_candidate):
            # Return the factor actually realized after the layer nudges the
            # candidate-scaled kernel: nudged kernel divided by the original kernel.
            kernel = layer.get_kernel()
            kernel_candidate = kernel * np.expand_dims(factor_candidate, 1)
            kernel_nudged = layer.get_nudged_kernel(kernel_candidate)
            factor_nudged = kernel_nudged / kernel
            return np.squeeze(factor_nudged)

        layer_norm_equiv_class = self._equalization_info_by_layer[lname]
        source_layer = self._model.layers[layer_norm_equiv_class.source]
        consumer_square_layer = self._model.layers[layer_norm_equiv_class.consumer_square]
        consumer_out_layer = self._model.layers[layer_norm_equiv_class.consumer_out]

        # The consumer_square input's last dimension is expected to be a whole
        # multiple (tiles) of the source input's last dimension.
        source_shape = source_layer.input_shape[-1]
        nudging_shape = consumer_square_layer.input_shape[-1]
        tiles = nudging_shape // source_shape

        # Global extremes over the statistics of the source layer's first input.
        input_stats = source_layer.get_input_stats()[0]
        min_all = np.min(input_stats.min)
        max_all = np.max(input_stats.max)

        # NOTE(review): no guard against zero-valued min/max statistics in the
        # divisions below - confirm the statistics can never be zero.
        factor_max = np.maximum(max_all / input_stats.max, min_all / input_stats.max)
        factor_min = np.maximum(min_all / input_stats.min, max_all / input_stats.min)

        factor_all = np.minimum(factor_max, factor_min)
        # Cap the source factor at 2**6; the consumer gets the squared inverse,
        # repeated `tiles` times.
        equalization_source_factor = np.minimum(factor_all, 2**6)
        equalization_consumer_factor = np.tile((1 / equalization_source_factor) ** 2, tiles)

        if self._nudging:
            # Use the factor the consumer kernel can actually represent and derive
            # the source factor back from its first source_shape entries.
            equalization_consumer_factor = _get_nudges_factor(consumer_square_layer, equalization_consumer_factor)
            equalization_source_factor = 1 / np.sqrt(equalization_consumer_factor[:source_shape])

        if DEBUG:
            # Debug-only report of the value limits before and after scaling.
            limvals_old = (np.min(input_stats.min), np.max(input_stats.max))
            limvals_new = (np.min(input_stats.min * factor_all), np.max(input_stats.max * factor_all))
            limvals_new1 = (
                np.min(input_stats.min * equalization_source_factor),
                np.max(input_stats.max * equalization_source_factor),
            )
            print("limvals_old", limvals_old)
            print("limvals_new", limvals_new)

            print(tiles, source_layer.input_shape, consumer_square_layer.input_shape)

            print("limvals_new1", limvals_new1)
            print("factor", np.max(equalization_source_factor))
            print("factor_all", np.max(factor_all))

        # Apply the factors: the source is scaled up, consumer_square by the
        # (possibly nudged) consumer factor, consumer_out by the inverse source factor.
        _import_new_kernel(source_layer, equalization_source_factor)
        _import_new_kernel(consumer_square_layer, equalization_consumer_factor)
        _import_new_kernel(consumer_out_layer, 1 / equalization_source_factor)
        self._split_layer(self._model, lname, equalization_source_factor)

    # endregion
