import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import CONCAT_AXIS_TO_DIM, FormatConversionType
from hailo_model_optimization.acceleras.utils.modification_tracker.modification_tracker_utils import (
    FoldNormalizationTracker,
)
from hailo_sdk_client.numeric_translator.bn_to_params import batch_norm_rescale_params, is_bn_info_param_key
from hailo_sdk_client.post_fuser.algorithms.exceptions import NormalizationOptimizingException
from hailo_sdk_client.post_fuser.algorithms.fuser_algorithm import FuserAlgorithm
from hailo_sdk_client.tools.fuser.fuser_helper import FuserHelper
from hailo_sdk_common.hailo_nn.hailo_nn import hn_to_npz_key
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, LayerType
from hailo_sdk_common.hailo_nn.hn_layers.normalization import NormalizationLayer
from hailo_sdk_common.model_params.model_params import ModelParams


class NormalizationOptimizer(FuserAlgorithm):
    NAME = "normalization_optimizer"

    def __init__(self, model, params, model_config, hw_arch, **kwargs):
        """
        Create the optimizer for ``model`` and its ``params``.

        Extra keyword arguments are accepted for interface compatibility and ignored.
        """
        super().__init__(model, params, model_config, hw_arch)
        # BN info params (keyed by NPZ key) collected for export_statistics
        self._bn_info = {}
        self._fuser_helper = FuserHelper(self._model)
        # names of layers that already absorbed a BN/normalization; a layer can currently absorb only one
        # (https://hailotech.atlassian.net/browse/SDK-57552)
        self.layers_with_folded_bn = []

    def get_algo_config(self):
        """Return the model config object held by this algorithm."""
        config = self._model_config
        return config

    def _setup(self):
        """Nothing to prepare before running this algorithm."""

    def _run_int(self):
        """Run every normalization pass in order, then refresh the model's input lists."""
        normalization_passes = (
            self._move_normalization_layers_before_folding,
            self._fold_post_layer_normalization_layers,
            self._fold_post_layer_batch_norm_layers,
            self._fold_pre_layer_normalization_layers,
            self._switch_norm_slice_to_slice_norm,
        )
        for run_pass in normalization_passes:
            run_pass()
        self._model.update_input_lists()

    def should_skip_algo(self):
        """This algorithm is never skipped."""
        skip = False
        return skip

    def log_config(self):
        """This algorithm has nothing to log about its configuration."""

    def export_statistics(self):
        """Return the batch-norm info params gathered during the run, keyed by NPZ key."""
        statistics = self._bn_info
        return statistics

    def _move_normalization_layers_before_folding(self):
        """
        Move normalization layers to be before / after fuseable layers.

        The moves run in a fixed order: onto the other operand of an element-wise op, in front of
        unfuseable layers, behind unfuseable layers, and finally onto second degree const inputs.
        """
        for move in (
            self._move_normalization_layers_to_neighbor,
            self._move_normalization_layers_before_unfuseable_layers,
            self._move_normalization_layers_after_unfuseable_layers,
            self._move_normalization_layers_to_second_degree_const_inputs_preds,
        ):
            move()

    def _move_normalization_layers_before_unfuseable_layers(self):
        """
        Move normalization layers to be before non-fuseable layers if their predecessors are fuseable.

        Candidates are normalization / batch_norm layers, and dw layers with static weights, that have params
        and linear activation. The predecessor must be a concat, an ew_add/ew_sub, or (only when the layer's
        bias is all zero) an ew_mult/ew_div, and it must have no other consumer. If
        ``_can_fuse_post_layer_normalization`` accepts the predecessor's inputs, the layer is swapped in front
        of the predecessor. One duplicate normalization is created for each additional input, and every copy
        gets its params through ``_update_new_norm_params`` with its relative input index.

        batch_norm layers pass the first filter but are not moved by the current code.
        """
        for layer in list(self._model):
            if layer.name in self._params and (
                layer.op in [LayerType.normalization, LayerType.batch_norm]
                or (
                    layer.op == LayerType.dw
                    and not layer.dynamic_weights  # dynamic_weights, weights are not stored in npz
                )
            ):
                bias = self._params[layer.name].get("bias:0")
                first_degree_preds = list(self._model.predecessors(layer))
                # NOTE(review): this filter only applies when there is exactly one predecessor; with any other
                # count it is bypassed and first_degree_preds[0] is used below - confirm that cannot happen here.
                if len(first_degree_preds) == 1 and not (
                    first_degree_preds[0].op == LayerType.concat
                    or (first_degree_preds[0].op in [LayerType.ew_add, LayerType.ew_sub])
                    or (first_degree_preds[0].op in [LayerType.ew_mult, LayerType.ew_div] and np.all(bias == 0))
                ):
                    continue

                # the predecessor's other consumers expect the un-normalized tensor
                if len(list(self._model.successors(first_degree_preds[0]))) > 1:
                    continue

                # TODO: https://hailotech.atlassian.net/browse/SDK-55426
                if layer.activation != ActivationType.linear:
                    continue

                if layer.op != LayerType.batch_norm:
                    second_degree_preds = list(self._model.predecessors(first_degree_preds[0]))
                    can_fuse = self._can_fuse_post_layer_normalization(
                        layer, second_degree_preds, is_multi_pred_layer=len(second_degree_preds) > 1
                    )
                    if not can_fuse:
                        continue

                    # duplicate the normalization layer per pred by creating a new normalization with same params
                    self._fuser_helper.swap_layers_order(layer, first_degree_preds[0], is_layer1_first=False)
                    # presumably the former predecessor (now after the normalization) - confirm swap semantics
                    norm_succs = list(self._model.successors(layer))
                    # the original layer keeps input 0; each further input gets its own copy
                    for pred in second_degree_preds[1:]:
                        relative_input_idx = second_degree_preds.index(pred)
                        new_norm_layer = self._duplicate_normalization_layer(
                            reference_normalization=layer,
                            duplication_idx=relative_input_idx,
                        )
                        # wire pred -> copy -> norm_succs and detach pred from the original layer
                        self._fuser_helper.add_preds(new_norm_layer, [pred])
                        self._fuser_helper.add_succs(new_norm_layer, norm_succs)
                        for norm_succ in norm_succs:
                            self._fuser_helper.add_preds(norm_succ, [new_norm_layer])
                        self._fuser_helper.replace_succ(pred, layer, new_norm_layer)
                        self._fuser_helper.remove_pred(layer, pred, update_input_shapes=False)
                        self._update_new_norm_params(
                            new_norm_layer.name, layer, first_degree_preds[0], relative_input_idx=relative_input_idx
                        )
                        # NOTE(review): relaxes the original `layer`, not `new_norm_layer`, once per copy - confirm
                        self._model.relax_new_layer_into_graph(layer, {})
                    self._update_new_norm_params(layer.name, layer, first_degree_preds[0], relative_input_idx=0)

    def _duplicate_normalization_layer(self, reference_normalization, duplication_idx):
        """
        Duplicate the normalization layer with the same params.

        Creates a new, not yet connected NormalizationLayer based on ``reference_normalization`` and sharing
        its output shapes. Callers wire it into the graph and set its params themselves.

        NOTE(review): the requested name depends only on ``duplication_idx``. Repeated calls with the same
        index ask for the same name unless ``create_layer`` makes it unique - confirm.
        """
        return self._fuser_helper.create_layer(
            NormalizationLayer,
            self._model.get_next_index(),
            f"duplication_{duplication_idx}",
            reference_normalization,
            [],
            reference_normalization.output_shapes,
        )

    def _move_normalization_layers_after_unfuseable_layers(self):
        """
        Move normalization layers to be after non-fuseable layers if their successors are fuseable.

        A normalization/dw layer is considered when its only successor is a slice or a flat<->frames format
        conversion. It is swapped behind that successor when the layers after the successor could absorb it by
        forward (pre-layer) folding. Its params are then adapted with ``_update_new_norm_params``.
        """
        movable_conversions = (FormatConversionType.flat_to_frames, FormatConversionType.frames_to_flat)
        for layer in list(self._model):
            if layer.op not in [LayerType.normalization, LayerType.dw]:
                continue

            first_degree_succs = list(self._model.successors(layer))
            # exactly one successor is required; a layer without successors must not reach the [0] index
            if len(first_degree_succs) != 1:
                continue

            succ = first_degree_succs[0]
            is_movable_succ = succ.op == LayerType.slice or (
                succ.op == LayerType.format_conversion and succ.conversion_type in movable_conversions
            )
            if not is_movable_succ:
                continue

            second_degree_succs = list(self._model.successors(succ))
            can_fuse, _ = self._can_fuse_pre_layer_normalization(
                layer, preds=list(self._model.predecessors(layer)), succs=second_degree_succs
            )
            if not can_fuse:
                continue

            self._fuser_helper.swap_layers_order(layer, succ, is_layer1_first=True)
            self._update_new_norm_params(layer.name, layer, succ)

    def _move_normalization_layers_to_neighbor(self):
        """
        In case of standalone normalization layer after unfuseable layer and before ew_op,
        move it to the neighbor layer if the neighbor is a fuseable layer or is a const.

        The normalization ``y = kernel * x + bias`` is moved, with unchanged params, onto another input of
        the element-wise op. This preserves the result only for commutative ops that the normalization
        commutes with:

        * ``ew_add`` when the normalization is a pure bias (kernel == 1): (a + b) + c == a + (c + b)
        * ``ew_mult`` when it is a pure scale (bias == 0): (k * a) * c == a * (k * c)

        ``ew_sub`` / ``ew_div`` are not handled. For example (a + b) - c != a - (c + b) and
        (k * a) / c != a / (k * c), so moving the unchanged normalization there would change the output.
        """
        for layer in list(self._model):
            if layer.name not in self._params:
                continue
            is_candidate = layer.op == LayerType.normalization or (
                layer.op == LayerType.dw and not layer.dynamic_weights and layer.activation == ActivationType.linear
            )
            if not is_candidate:
                continue

            preds = list(self._model.predecessors(layer))
            succs = list(self._model.successors(layer))
            if len(preds) != 1 or len(succs) != 1:
                continue

            kernel = self._params[layer.name].get("kernel:0")
            bias = self._params[layer.name].get("bias:0")
            pred = preds[0]
            succ = succs[0]
            # only commutative ops that the (unchanged) normalization commutes with
            if not (
                (succ.op == LayerType.ew_add and np.all(kernel == 1))
                or (succ.op == LayerType.ew_mult and np.all(bias == 0))
            ):
                continue

            ew_op_preds = list(self._model.predecessors(succ))
            ew_op_preds.remove(layer)
            # first other operand that could absorb the normalization by backward folding
            folding_candidate = next(
                (
                    ew_op_pred
                    for ew_op_pred in ew_op_preds
                    if self._can_fuse_post_layer_normalization(layer, [ew_op_pred], is_multi_pred_layer=False)
                ),
                None,
            )
            if folding_candidate is None:
                continue

            self._move_normalization(layer, old_pred=pred, new_pred=folding_candidate, old_pred_new_succ=succ)

    def _move_normalization(
        self, normalization_layer, old_pred, new_pred, old_pred_new_succ, norm_new_succ=None, detach=False
    ):
        """
        Move normalization layer to be after new pred layer.

        Only graph connectivity is changed here; params are not touched.

        Args:
            normalization_layer: the layer to move (currently fed by ``old_pred``).
            old_pred: the current predecessor of ``normalization_layer``.
            new_pred: the layer ``normalization_layer`` should follow after the move.
            old_pred_new_succ: (``detach=False`` only) the layer fed by ``new_pred``. After the move,
                ``old_pred`` takes ``new_pred``'s place as one of its inputs. Unused when ``detach=True``.
            norm_new_succ: (``detach=True`` only) the layer that should consume ``normalization_layer``.
            detach: when False, swap the roles of ``old_pred`` and ``new_pred`` around ``old_pred_new_succ``.
                When True, reconnect the normalization's current successors to ``old_pred`` and insert the
                normalization between ``new_pred`` and ``new_pred``'s first successor.
        """
        if not detach:
            # normalization now reads new_pred, and old_pred feeds old_pred_new_succ in new_pred's place
            self._fuser_helper.replace_pred(normalization_layer, old_pred=old_pred, new_pred=new_pred)
            self._fuser_helper.replace_succ(new_pred, old_succ=old_pred_new_succ, new_succ=normalization_layer)
            self._fuser_helper.replace_succ(old_pred, old_succ=normalization_layer, new_succ=old_pred_new_succ)
            self._fuser_helper.replace_pred(old_pred_new_succ, old_pred=new_pred, new_pred=old_pred)
        else:
            # bypass the normalization: its current consumers read old_pred directly
            norm_old_succs = list(self._model.successors(normalization_layer))
            self._fuser_helper.add_succs(normalization_layer, [norm_new_succ])
            for old_succ in norm_old_succs:
                self._fuser_helper.replace_pred(old_succ, old_pred=normalization_layer, new_pred=old_pred)
                self._fuser_helper.remove_succ(normalization_layer, old_succ)
            # NOTE(review): rebinds norm_new_succ to new_pred's first successor, overriding the argument that was
            # added as a successor above; callers presumably pass that same layer - confirm.
            # NOTE(review): old_pred's own successor list is not explicitly updated in this branch - confirm the
            # helpers above keep it consistent.
            norm_new_succ = next(iter(self._model.successors(new_pred)))
            # insert the normalization between new_pred and its consumer
            self._fuser_helper.replace_pred(normalization_layer, old_pred=old_pred, new_pred=new_pred)
            self._fuser_helper.replace_succ(new_pred, old_succ=norm_new_succ, new_succ=normalization_layer)
            self._fuser_helper.replace_pred(norm_new_succ, old_pred=new_pred, new_pred=normalization_layer)

    def _move_normalization_layers_to_second_degree_const_inputs_preds(self):
        """
        Move normalization layers to be after the second degree const inputs preds.

        Matches ``layer(ew_add(ew_mult(.., const_a), ew_mult(.., const_b)))``. ``layer`` must be a
        normalization/dw layer with params, static weights, linear activation, stride/dilation 1, a 1x1 kernel
        and a zero (or missing) bias, i.e. a pure scale. Such a scale distributes over the ew_add and commutes
        with the ew_mults. So ``layer`` is re-attached after ``const_a``, and a duplicate of it is attached
        after ``const_b``. Both const inputs must pass ``_can_fuse_const_input``.
        """
        for layer in list(self._model):
            if layer.name not in self._params or layer.op not in (LayerType.dw, LayerType.normalization):
                continue

            if (
                layer.dynamic_weights
                or layer.activation != ActivationType.linear
                or layer.strides != [1, 1, 1, 1]
                or layer.dilations != [1, 1, 1, 1]
            ):
                continue

            layer_params = self._params[layer.name]
            kernel = layer_params.get("kernel:0")
            k_h, k_w, _, _ = np.shape(kernel)
            bias = layer_params.get("bias:0")

            # Only a pure scale can be moved. A non-zero bias, or a kernel that is not 1x1 in BOTH spatial
            # dimensions (e.g. 1x3), does not commute with the element-wise ops.
            has_bias = bias is not None and not np.all(bias == 0)
            if has_bias or k_h != 1 or k_w != 1:
                continue

            preds = list(self._model.predecessors(layer))
            if len(preds) != 1 or preds[0].op != LayerType.ew_add:
                continue
            pred = preds[0]

            second_degree_preds = list(self._model.predecessors(pred))
            if len(second_degree_preds) != 2:
                continue

            # ew_mult name -> the const input feeding it
            const_input_dict = {}
            is_fuseable = True
            for second_degree_pred in second_degree_preds:
                if second_degree_pred.op != LayerType.ew_mult:
                    continue

                third_degree_preds = list(self._model.predecessors(second_degree_pred))
                if len(third_degree_preds) != 2:
                    is_fuseable = False
                    break

                if not any(
                    third_degree_pred.op == LayerType.const_input for third_degree_pred in third_degree_preds
                ):
                    is_fuseable = False
                    break

                for third_degree_pred in third_degree_preds:
                    if third_degree_pred.op == LayerType.const_input and not self._can_fuse_const_input(
                        layer, third_degree_pred
                    ):
                        is_fuseable = False
                        break

                    if third_degree_pred.op == LayerType.const_input:
                        const_input_dict[second_degree_pred.name] = third_degree_pred
                        break

            # both ew_add inputs must be ew_mults fed by a fuseable const input
            if not is_fuseable or len(const_input_dict) != 2:
                continue

            first_mult, second_mult = second_degree_preds
            self._move_normalization(
                normalization_layer=layer,
                old_pred=pred,
                new_pred=const_input_dict[first_mult.name],
                old_pred_new_succ=list(self._model.successors(layer)),
                norm_new_succ=first_mult,
                detach=True,
            )

            duplicated_normalization = self._duplicate_normalization_layer(
                reference_normalization=layer,
                duplication_idx=0,
            )
            # attach the duplicated normalization layer to the remaining second degree const input pred
            second_const = const_input_dict[second_mult.name]
            self._fuser_helper.add_preds(duplicated_normalization, [second_const])
            self._fuser_helper.add_succs(duplicated_normalization, [second_mult], update_output_shapes=False)
            self._fuser_helper.replace_succ(second_const, second_mult, duplicated_normalization)
            self._fuser_helper.replace_pred(second_mult, second_const, duplicated_normalization)
            self._update_new_norm_params(
                duplicated_normalization.name,
                layer,
                None,
                None,
                duplicate_params=True,
            )

    def _fold_pre_layer_normalization_layers(self):
        """
        Allow forward folding of normalizations to conv/dw/deconv only when kernel is 1x1,
        otherwise the new bias calculation is very long and inefficient, and because paddings cannot be supported

        For every normalization/dw layer accepted by ``_can_fuse_pre_layer_normalization``:
        the layer's predecessor is connected directly to the layer's successors, the layer's params are folded
        forward into each destination layer (which is then recorded in ``layers_with_folded_bn``), and the layer
        is removed. Finally the model params are rebuilt with ``batch_norm_rescale_params``
        (``keep_normalization_params=False``).
        """
        layers_to_remove = []

        for layer in list(self._model):
            if layer.op in [
                LayerType.normalization,
                LayerType.dw,
            ]:  # , LayerType.batch_norm]: # https://hailotech.atlassian.net/browse/SDK-57993
                preds = list(self._model.predecessors(layer))
                succs = list(self._model.successors(layer))
                can_fuse, dst_layers = self._can_fuse_pre_layer_normalization(layer, preds, succs)
                if not can_fuse:
                    continue

                pred = next(iter(self._model.predecessors(layer)))
                succs = list(self._model.successors(layer))

                # bypass the layer: pred feeds each successor directly
                for succ in succs:
                    if layer.name in pred.outputs:
                        # first successor reuses the output entry that pointed at the removed layer
                        pred.replace_output_shape(layer.name, layer.output_shape)
                        pred.replace_output_index(layer.index, succ.index)
                        pred.replace_output_layer(layer.name, succ.name)
                    else:
                        pred.append_output_shape(layer.output_shape)
                        pred.append_output_index(succ.index)
                        pred.append_output_layer(succ.name)

                    index = pred.outputs.index(succ.name)
                    succ.replace_input_shape(layer.name, pred.output_shapes[index])
                    succ.replace_input_index(layer.index, pred.index)
                    succ.replace_input_layer(layer.name, pred.name)

                    self._model.remove_edge(layer, succ)
                    self._model.add_edge(pred, succ)

                src_params = self._params[layer.name]
                for dst_layer in dst_layers:
                    # forward fold normalization params (mean, std) over successor kernel/bias
                    self._fold_bn_or_normalization_params(layer, dst_layer, src_params, pre_layer_bn=True)
                    self.layers_with_folded_bn.append(dst_layer.name)

                    # iterated in reverse with reverse_insertion=True, presumably so the names land ahead of
                    # dst's own names in their original order - confirm add_original_name semantics
                    for name in layer.original_names[::-1]:
                        dst_layer.add_original_name(name, reverse_insertion=True)

                self._model.remove_edge(pred, layer)
                layers_to_remove.append(layer)
                self._logger.debug(
                    f"Folded {layer.op.value} layer {layer.name} onto it's successors {', '.join(layer.outputs)}",
                )

        # removal is deferred so the graph is not mutated while the loop above inspects it
        for layer_to_remove in layers_to_remove:
            self._model.remove_layer(layer_to_remove)

        self._params = ModelParams(
            batch_norm_rescale_params(model=self._model, params=self.params, keep_normalization_params=False),
        )

    def _can_fuse_pre_layer_normalization(self, layer, preds, succs):
        """
        Check whether ``layer`` can be folded forward (pre-layer) into the layers that consume it.

        Args:
            layer: normalization/dw candidate; must be linear, static, 1x1, stride/dilation 1, without output
                layout flags and (unless it is a batch norm) without an attached BN.
            preds: predecessors of ``layer``; exactly one is required.
            succs: successors of ``layer``. A spatial-reshape successor is looked through and its own
                successors are checked instead. The list is NOT modified.

        Returns:
            ``(can_fuse, dst_layers)``. ``dst_layers`` are the conv/dw/deconv/dense layers (1x1, static weights,
            single input, not already holding a folded BN) that would absorb the params. ``(False, [])`` is
            returned when folding is impossible, including when no destination layer exists (removing the
            layer then would silently drop the normalization).
        """
        # exactly one input: with none, there is nothing to reconnect the successors to
        if len(preds) != 1:
            return False, []

        if (
            layer.activation != ActivationType.linear
            or (layer.op != LayerType.batch_norm and layer.bn_enabled)
            or layer.dynamic_weights
            or layer.transpose_output_width_features
            or layer.spatial_flatten_output
        ):
            return False, []

        if (hasattr(layer, "kernel_height") and layer.kernel_height != 1) or (
            hasattr(layer, "kernel_width") and layer.kernel_width != 1
        ):
            return False, []

        if layer.strides != [1, 1, 1, 1] or layer.dilations != [1, 1, 1, 1]:
            return False, []

        # local work list (grows when looking through spatial reshapes); the caller's list stays untouched
        pending_succs = list(succs)
        dst_layers = []
        for succ in pending_succs:
            if len(list(self._model.predecessors(succ))) > 1:
                # cant fold forward to a layer with more than one input, since we can't fold for a single input
                return False, []

            if self._is_spatial_reshape(succ):
                pending_succs.extend(self._model.successors(succ))
                continue

            if succ.op not in [LayerType.conv, LayerType.dw, LayerType.deconv, LayerType.dense]:
                return False, []

            # a dense destination must take exactly layer.kernel_shape[2] input features
            if succ.op == LayerType.dense and succ.kernel_shape[0] != layer.kernel_shape[2]:
                return False, []

            if succ.op != LayerType.dense and succ.layer_disparity > 1:
                return False, []

            if succ.dynamic_weights:
                return False, []

            if succ.kernel_width != 1 or succ.kernel_height != 1:
                return False, []

            # a layer can currently absorb a single BN/normalization only (SDK-57552)
            if succ.name in self.layers_with_folded_bn:
                return False, []

            dst_layers.append(succ)

        if not dst_layers:
            return False, []

        return True, dst_layers

    def _fold_post_layer_normalization_layers(self):
        """
        Allow basic backward folding of normalization layers to conv/dw/deconv/dense.

        Each normalization/dw layer accepted by ``_can_fuse_post_layer_normalization`` is removed. Its
        predecessor is connected directly to its successors. Its params go to the destination layer: the
        predecessor, or the layer before it when the predecessor is a spatial reshape. A ``reduce_mean``
        predecessor takes them through ``_fold_normalization_as_reduce_mean_bias``; otherwise
        ``_fold_bn_or_normalization_params`` folds them into the destination's kernel/bias. Entries of the
        output-layers order that name the removed layer are renamed to the predecessor.
        """
        layers_to_remove = []
        for layer in list(self._model):
            if layer.op not in [LayerType.normalization, LayerType.dw]:
                continue

            succs = list(self._model.successors(layer))
            preds = list(self._model.predecessors(layer))

            if not self._can_fuse_post_layer_normalization(layer, preds, is_multi_pred_layer=len(preds) > 1):
                continue

            pred = preds[0]
            if self._is_spatial_reshape(pred):
                dst_layer = next(iter(self._model.predecessors(pred)))
            else:
                # dst now produces the layer's output directly, so it inherits the output layout flags
                dst_layer = pred
                dst_layer.transpose_output_width_features = layer.transpose_output_width_features
                dst_layer.spatial_flatten_output = layer.spatial_flatten_output

            # in case dst activation is not linear its params will be manipulated instead of changing it.
            if dst_layer.op != LayerType.const_input:
                if dst_layer.activation == ActivationType.linear:
                    dst_layer.activation = layer.activation

                # The folded layer's BN params end up in dst's params, so dst must be marked BN-enabled.
                # Previously dst's flag was assigned to itself, which was a no-op.
                if layer.bn_enabled:
                    dst_layer.bn_enabled = True

            for name in layer.original_names:
                dst_layer.add_original_name(name)

            # bypass the layer: pred feeds each successor directly
            for succ in succs:
                if layer.name in pred.outputs:
                    pred.replace_output_shape(layer.name, layer.output_shape)
                    pred.replace_output_index(layer.index, succ.index)
                    pred.replace_output_layer(layer.name, succ.name)
                else:
                    pred.append_output_shape(layer.output_shape)
                    pred.append_output_index(succ.index)
                    pred.append_output_layer(succ.name)

                succ.replace_input_shape(layer.name, pred.output_shape)
                succ.replace_input_index(layer.index, pred.index)
                succ.replace_input_layer(layer.name, pred.name)

                self._model.remove_edge(layer, succ)
                self._model.add_edge(pred, succ)

            if pred.op == LayerType.reduce_mean:
                # move params from normalization layer to predecessor reduce mean
                self._fold_normalization_as_reduce_mean_bias(layer, dst_layer)
            else:
                # move params from batch norm/normalization layer to predecessor conv/dense
                self._fold_bn_or_normalization_params(
                    layer,
                    dst_layer,
                    self._params[layer.name],
                    pre_layer_bn=False,
                )
            self._model.remove_edge(pred, layer)

            for i, output in enumerate(self._model.net_params.output_layers_order):
                if layer.name == output:
                    self._model.net_params.output_layers_order[i] = pred.name

            self._logger.debug(f"Folded {layer.op.value} layer {layer.name} onto predecessor layer {pred.name}")
            layers_to_remove.append(layer)

        for layer_to_remove in layers_to_remove:
            self._model.remove_layer(layer_to_remove)

    def _fold_post_layer_batch_norm_layers(self):
        """
        Allow basic folding of batch norm layers to conv/dw/deconv/dense.
        Assume weights are later folded as well in the native NPZ.

        A batch norm rejected by ``_can_fuse_batch_norm`` stays in the graph, and its BN info params are
        recorded for export. Otherwise the BN is removed: its predecessor is connected to its successors and
        its params go to the destination layer. The destination is the predecessor or, when the predecessor is
        a spatial reshape, the layer before the reshape. That is the layer ``_can_fuse_batch_norm`` validated.

        Raises:
            NormalizationOptimizingException: for a standalone (unfoldable) batch norm with a rank-2 input.
        """
        layers_to_remove = []
        for layer in list(self._model):
            if layer.op != LayerType.batch_norm:
                continue

            if not self._can_fuse_batch_norm(layer):
                if len(layer.input_shape) == 2:
                    raise NormalizationOptimizingException(
                        f"{layer.full_name_msg} is a standalone batch norm in rank 2 which is not supported",
                    )

                self._bn_info.update(
                    {
                        hn_to_npz_key(layer.name, k): v
                        for k, v in self._params[layer.name].items()
                        if is_bn_info_param_key(k)
                    },
                )
                continue

            succs = list(self._model.successors(layer))
            pred = next(iter(self._model.predecessors(layer)))
            is_pred_spatial_reshape = self._is_spatial_reshape(pred)

            # Behind a spatial reshape, the layer *before* the reshape is the one _can_fuse_batch_norm checked,
            # so it receives the params (as in _fold_post_layer_normalization_layers). The BN's own successor
            # is not a valid destination.
            dst_layer = next(iter(self._model.predecessors(pred))) if is_pred_spatial_reshape else pred
            if dst_layer.name in self.layers_with_folded_bn:
                continue

            if not is_pred_spatial_reshape:
                # dst now produces the BN's output directly, so it inherits the BN's output layout flags
                dst_layer.transpose_output_width_features = layer.transpose_output_width_features
                dst_layer.spatial_flatten_output = layer.spatial_flatten_output

            dst_layer.activation = layer.activation
            dst_layer.bn_enabled = True
            for name in layer.original_names:
                dst_layer.add_original_name(name)

            # bypass the BN: pred feeds each successor directly
            for succ in succs:
                if layer.name in pred.outputs:
                    pred.replace_output_index(layer.index, succ.index)
                    pred.replace_output_layer(layer.name, succ.name)
                    pred.replace_output_shape(layer.name, layer.output_shape)
                else:
                    pred.append_output_index(succ.index)
                    pred.append_output_layer(succ.name)
                    pred.append_output_shape(layer.output_shape)

                succ.replace_input_index(layer.index, pred.index)
                succ.replace_input_layer(layer.name, pred.name)
                succ.replace_input_shape(layer.name, pred.output_shape)

                self._model.remove_edge(layer, succ)
                self._model.add_edge(pred, succ)

            # move params from batch norm/normalization layer to predecessor conv/dense
            self._fold_bn_or_normalization_params(layer, dst_layer, self._params[layer.name], pre_layer_bn=False)
            self.layers_with_folded_bn.append(dst_layer.name)
            self._model.remove_edge(pred, layer)

            for i, output in enumerate(self._model.net_params.output_layers_order):
                if layer.name == output:
                    self._model.net_params.output_layers_order[i] = pred.name

            self._logger.debug(f"Folded {layer.op.value} layer {layer.name} onto predecessor layer {pred.name}")
            layers_to_remove.append(layer)

        for layer_to_remove in layers_to_remove:
            self._model.remove_layer(layer_to_remove)

    def _fold_bn_or_normalization_params(self, src_layer, dst_layer, src_params, pre_layer_bn=False):
        """
        Fold the params of ``src_layer`` into ``dst_layer``'s params and drop ``src_layer``'s params.

        Cases:
            * ``src_layer`` is a batch norm: the native BN params are copied to ``dst_layer`` and recorded in the
              exported BN info; ``batch_norm_rescale_params`` applies them later.
            * ``dst_layer`` is a const input: ``const_data`` becomes ``const_data * kernel + bias``.
            * otherwise the normalization ``y = kernel * x + bias`` is folded into ``dst_layer``'s kernel/bias.
              With ``pre_layer_bn=False`` the normalization comes after ``dst_layer`` (backward folding). With
              ``pre_layer_bn=True`` it comes before ``dst_layer``, and only 1x1 kernels are supported.

        Args:
            src_layer: batch norm / normalization / dw layer being folded away.
            dst_layer: layer receiving the params.
            src_params: ``src_layer``'s params (``kernel:0`` / ``bias:0`` for normalizations).
            pre_layer_bn: True for forward folding into a successor.

        Raises:
            NormalizationOptimizingException: forward folding into a layer whose kernel is not 1x1.
        """
        dst_params = self._params[dst_layer.name]

        new_params = {
            hn_to_npz_key(dst_layer.name, k): v for k, v in src_params.items() if k not in ["kernel:0", "bias:0"]
        }

        if src_layer.op == LayerType.batch_norm:
            # For BN we just bring back native batch norm params, batch_norm_rescale_params will use it
            self._bn_info.update(
                {hn_to_npz_key(dst_layer.name, k): v for k, v in src_params.items() if is_bn_info_param_key(k)},
            )
            dst_layer.pre_layer_bn = bool(pre_layer_bn)
        elif dst_layer.op == LayerType.const_input:
            # need to convert normalization params for backward folding over const_data
            kernel = src_params["kernel:0"].flatten().astype(np.float128)
            bias = src_params["bias:0"].flatten().astype(np.float128)
            if dst_layer.input_tiles[0][2] != 1:
                # tiled const: params were verified to be squeezable (see _can_fuse_const_input)
                kernel = np.unique(src_params["kernel:0"], axis=2).flatten().astype(np.float128)
                bias = np.unique(src_params["bias:0"]).flatten().astype(np.float128)
            new_params = {hn_to_npz_key(dst_layer.name, k): v for k, v in dst_params.items() if k != "const_data:0"}
            const_data = dst_params["const_data:0"].astype(np.float128)
            const_data = const_data * kernel + bias
            new_params.update({hn_to_npz_key(dst_layer.name, "const_data"): const_data.astype(np.float32)})
        else:
            # need to convert normalization params for backward folding over kernel/bias
            kernel = src_params["kernel:0"].flatten().astype(np.float128)
            bias = src_params["bias:0"].flatten().astype(np.float128)

            dst_kernel = dst_params["kernel:0"]
            dst_bias = dst_params["bias:0"]
            new_kernel = np.copy(dst_kernel).astype(np.float128)
            new_bias = np.copy(dst_bias).astype(np.float128)

            kernel_key = self._modifications_meta_data.add_modification_param("kernel", src_params["kernel:0"])
            bias_key = self._modifications_meta_data.add_modification_param("bias", src_params["bias:0"])

            # Forward folding normalization is only supported for 1x1 kernels, due to the high time consumption
            # of the inefficient calculation of the new bias, and problem with supporting paddings
            if pre_layer_bn:
                if dst_layer.op != LayerType.dense:
                    k_h, k_w, f_in, f_out = np.shape(new_kernel)
                    new_kernel_shape = [1, 1, f_in, 1]
                else:
                    f_in, f_out = np.shape(new_kernel)
                    k_h, k_w = 1, 1
                    new_kernel_shape = [f_in, 1]
                # both spatial dims must be 1; 1xN / Nx1 kernels are rejected too
                if k_h != 1 or k_w != 1:
                    raise NormalizationOptimizingException(
                        f"Cannot forward fold normalization layer {src_layer.name} to "
                        f"convolutional layer {dst_layer.name} that has kernel != 1x1.",
                    )
                dst_kernel_2d = new_kernel.reshape([f_in, f_out])
                if dst_layer.op == LayerType.dw:
                    # depthwise: output (c, m) only sees input channel c, so the shift is bias[c] * w[c, m],
                    # not a sum over all input channels
                    bias_shift = (bias.reshape([f_in, 1]) * dst_kernel_2d).reshape([-1])
                else:
                    # W^T (k * x + b) + c == (W * k)^T x + (W^T b + c)
                    bias_shift = np.matmul(bias, dst_kernel_2d).reshape([f_out])
                new_bias = new_bias + bias_shift
                new_kernel *= kernel.reshape(new_kernel_shape)
                self._modifications_meta_data.append(
                    dst_layer.name,
                    FoldNormalizationTracker(
                        kernel_key=kernel_key,
                        bias_key=bias_key,
                        apply_on_input=True,
                    ),
                )
            else:
                # k * (W x + c) + b == (k * W) x + (k * c + b), k applied per output channel
                new_bias = new_bias * kernel + bias
                if dst_layer.op in [LayerType.dw, LayerType.normalization]:
                    kernel = kernel.reshape([1, 1, kernel.shape[0], 1])
                new_kernel *= kernel
                self._modifications_meta_data.append(
                    dst_layer.name,
                    FoldNormalizationTracker(
                        kernel_key=kernel_key,
                        bias_key=bias_key,
                        apply_on_input=False,
                    ),
                )

            new_params.update(
                {
                    hn_to_npz_key(dst_layer.name, "kernel"): new_kernel.astype(np.float32),
                    hn_to_npz_key(dst_layer.name, "bias"): new_bias.astype(np.float32),
                },
            )
            if (
                not pre_layer_bn
                and src_layer.activation == ActivationType.linear
                and dst_layer.activation == ActivationType.clip
            ):
                # NOTE(review): clip bounds are scaled by the first kernel element only. This assumes a uniform,
                # positive kernel and ignores the bias - confirm _can_fuse_post_layer_normalization enforces that.
                new_params.update(
                    {
                        hn_to_npz_key(dst_layer.name, "clip_min"): np.array(dst_params["clip_min:0"] * kernel.item(0)),
                        hn_to_npz_key(dst_layer.name, "clip_max"): np.array(dst_params["clip_max:0"] * kernel.item(0)),
                    },
                )

        self._params.update(new_params)
        self._params.remove(src_layer.name)

    def _can_fuse_batch_norm(self, layer, is_multi_pred_layer=False):
        """
        Check whether batch norm ``layer`` can be folded backward into its predecessor(s).

        A spatial-reshape predecessor is looked through when the BN has no output layout flags and the reshape
        has a single consumer. Every resulting predecessor must meet all of these conditions:
            * it is a conv/dw/deconv/dense layer;
            * it has static weights and no output layout flags;
            * it has a single consumer;
            * it has linear activation and no BN / ew_add of its own;
            * except for dense, it has ``layer_disparity <= 1``.

        Args:
            layer: the batch norm layer.
            is_multi_pred_layer: allow more than one predecessor.

        Returns:
            True when folding is possible.
        """
        preds = list(self._model.predecessors(layer))
        # a BN without inputs has nothing to fold into (callers take its first predecessor unconditionally)
        if not preds:
            return False

        if not is_multi_pred_layer and len(preds) > 1:
            return False

        if layer.ew_add_enabled or layer.dynamic_weights:
            return False

        for pred in preds:
            if (
                self._is_spatial_reshape(pred)
                and not (layer.transpose_output_width_features or layer.spatial_flatten_output)
                and len(list(self._model.successors(pred))) == 1
            ):
                pred = next(iter(self._model.predecessors(pred)))

            if pred.op not in [LayerType.conv, LayerType.dw, LayerType.deconv, LayerType.dense]:
                return False

            if pred.dynamic_weights or pred.transpose_output_width_features or pred.spatial_flatten_output:
                return False

            # can't fold BN over a layer that has more than one output (the second output expects non-normalized tensor)
            if len(list(self._model.successors(pred))) > 1:
                return False

            if pred.activation != ActivationType.linear or pred.bn_enabled or pred.ew_add_enabled:
                return False

            if pred.op != LayerType.dense and pred.layer_disparity > 1:
                return False

        return True

    def _can_fuse_const_input(self, normalization_layer, const_input_layer):
        """
        fusing to const input is supported only for untiled const inputs or when the const input is tiled and
        normalization kernel and bias can be squeezeed to match the tiled shape
        """
        kernel_unique_shape = np.unique(self._params[normalization_layer.name].get("kernel:0"), axis=2).shape
        bias_unique_shape = np.unique(self._params[normalization_layer.name].get("bias:0")).shape
        return not (
            const_input_layer.input_tiles[0][2] != 1 and (kernel_unique_shape[2] != 1 or bias_unique_shape[0] != 1)
        )

    def _can_fuse_post_layer_normalization(self, layer, preds, is_multi_pred_layer=False):
        """
        Check whether ``layer`` can be folded into the layers that precede it.

        ``layer`` itself must be a 1x1 op with unit strides and dilations, no ew-add and no dynamic weights.
        Every predecessor must accept the fold. Const-input predecessors are delegated to
        ``_can_fuse_const_input``. A single-consumer spatial-reshape predecessor is looked through to the layer
        before it.

        :param layer: the candidate layer to fold.
        :param preds: the predecessors of ``layer``.
        :param is_multi_pred_layer: when False, a layer with more than one predecessor is rejected.
        :return: True only if ``layer`` and *all* of ``preds`` allow the fold.
        """
        if not is_multi_pred_layer and len(preds) > 1:
            return False

        if layer.ew_add_enabled or layer.dynamic_weights:
            return False

        if layer.kernel_height != 1 or layer.kernel_width != 1:
            return False

        if layer.strides != [1, 1, 1, 1] or layer.dilations != [1, 1, 1, 1]:
            return False

        foldable_ops = [
            LayerType.conv,
            LayerType.dw,
            LayerType.deconv,
            LayerType.dense,
            LayerType.normalization,
            LayerType.reduce_mean,
        ]
        for pred in preds:
            if pred.op == LayerType.const_input:
                if not self._can_fuse_const_input(layer, pred):
                    return False
                continue

            if (
                self._is_spatial_reshape(pred)
                and not (layer.transpose_output_width_features or layer.spatial_flatten_output)
                and len(list(self._model.successors(pred))) == 1
            ):
                pred = next(iter(self._model.predecessors(pred)))

            if pred.op not in foldable_ops:
                return False

            if pred.op == LayerType.reduce_mean:
                # a reduce_mean predecessor only accepts a normalization whose kernel is all ones (bias only)
                kernel = self._params[hn_to_npz_key(layer.name, "kernel")]
                if np.any(kernel != 1):
                    return False

            if pred.op in [LayerType.conv, LayerType.dw, LayerType.deconv] and pred.layer_disparity > 1:
                return False

            if pred.dynamic_weights or pred.transpose_output_width_features or pred.spatial_flatten_output:
                return False

            if layer.bn_enabled and pred.op == LayerType.normalization:
                return False

            # can't fold normalization over a layer that has more than one output (the second output expects
            # non-normalized tensor)
            if len(list(self._model.successors(pred))) > 1:
                return False

            if pred.activation != ActivationType.linear:
                # A non-linear activation commutes with the normalization only when the normalization is a
                # uniform, strictly positive scale with zero bias and adds no activation of its own
                # (e.g. relu(x) * k == relu(k * x) for k > 0).
                if layer.activation != ActivationType.linear or pred.activation not in [
                    ActivationType.relu,
                    ActivationType.clip,
                ]:
                    return False
                layer_params = self._params[layer.name]
                kernel = layer_params.get("kernel:0")
                bias = layer_params.get("bias:0")
                if kernel is None or bias is None:
                    return False
                # np.unique(...).size == 1 is a shape-independent "all values are equal" test
                if np.unique(kernel).size != 1 or not np.all(kernel > 0) or not np.all(bias == 0):
                    return False
                # Fall through (instead of returning True) so the bn/ew-add checks below still apply to this
                # pred, and so the remaining preds of a multi-pred layer are still checked.

            # the fold rescales only this pred's own weights: a BN or an ew-added tensor would not be rescaled
            if pred.bn_enabled or pred.ew_add_enabled:
                return False

        return True

    def _fold_normalization_as_reduce_mean_bias(self, src_layer, dst_layer):
        """
        Add ``src_layer``'s bias onto ``dst_layer``'s bias, then drop ``src_layer``'s params.

        When ``dst_layer`` has no bias yet, a scalar float32 zero is used as its starting value. The folded
        bias is stored as float32.
        """
        bias_key = hn_to_npz_key(dst_layer.name, "bias")
        folded_bias = self._params[hn_to_npz_key(src_layer.name, "bias")] + self._params.get(
            bias_key, np.array(0.0, dtype=np.float32)
        )
        self._params.update({bias_key: folded_bias.astype(np.float32)})
        self._params.remove(src_layer.name)

    def _is_spatial_reshape(self, layer):
        """Return True if ``layer`` is a format conversion of type transpose_height_width or spatial_reshape."""
        if layer.op != LayerType.format_conversion:
            return False
        spatial_conversions = (
            FormatConversionType.transpose_height_width,
            FormatConversionType.spatial_reshape,
        )
        return layer.conversion_type in spatial_conversions

    def _switch_norm_slice_to_slice_norm(self):
        """
        Reorder Normalization -> FeatureSlicer into FeatureSlicer -> Normalization.

        This applies only where the slice layer's single predecessor is a normalization whose only consumer is
        that slice. After the swap, the normalization params are re-derived for the sliced features (see
        ``_update_new_norm_params``). This lets the normalization layer use fewer resources.
        """
        # The graph is snapshotted once up front, because swapping reorders layers while we iterate.
        slice_layers = (candidate for candidate in list(self._model) if candidate.op == LayerType.slice)
        for slice_layer in slice_layers:
            predecessors = list(self._model.predecessors(slice_layer))
            if len(predecessors) != 1:
                continue
            norm_layer = predecessors[0]
            if norm_layer.op != LayerType.normalization:
                continue
            if len(list(self._model.successors(norm_layer))) != 1:
                continue

            self._fuser_helper.swap_layers_order(norm_layer, slice_layer, is_layer1_first=True)
            self._update_new_norm_params(norm_layer.name, norm_layer, slice_layer)

    def _update_new_norm_params(
        self,
        new_normalization_layer_name,
        ref_normalization_layer,
        reference_layer,
        relative_input_idx=None,
        duplicate_params=False,
    ):
        """
        Write the params of a (possibly new) normalization layer derived from an existing normalization layer.

        The kernel and bias of ``ref_normalization_layer`` are transformed according to the layer the
        normalization is moved across (``reference_layer``). Every other param is copied unchanged.

        :param new_normalization_layer_name: name of the layer whose params are written.
        :param ref_normalization_layer: normalization layer whose params are the source.
        :param reference_layer: layer the normalization is moved across. Its op (slice / concat / ew_mult /
            ew_div / ew_add / ew_sub) selects the transformation.
        :param relative_input_idx: index of the ``reference_layer`` input that the new normalization feeds.
            Used for feature-axis concat and for ew_sub.
        :param duplicate_params: copy kernel and bias unchanged, regardless of ``reference_layer``.
        :raises NormalizationOptimizingException: if ``duplicate_params`` is False and ``reference_layer.op``
            is not one of the supported ops.
        """
        normalization_params = self._params[ref_normalization_layer.name]
        # Copy every param except kernel/bias. Keys stay relative (e.g. "clip_min:0") and are prefixed with the
        # layer name exactly once, below.
        new_normalization_params = {
            k: v for k, v in normalization_params.items() if k not in ["kernel:0", "bias:0"]
        }

        original_kernel_shape = normalization_params["kernel:0"].shape
        original_bias_shape = normalization_params["bias:0"].shape
        # Compute in extended precision. np.longdouble is np.float128 wherever that exists, and it also works on
        # platforms (e.g. Windows, macOS arm64) where np.float128 is not defined. flatten/astype already copy.
        kernel = normalization_params["kernel:0"].flatten().astype(np.longdouble)
        bias = normalization_params["bias:0"].flatten().astype(np.longdouble)

        if duplicate_params:
            new_kernel = kernel.reshape(original_kernel_shape)
            new_bias = bias.reshape(original_bias_shape)
        elif reference_layer.op == LayerType.slice:
            slice_start = reference_layer.features_slice[0]
            slice_end = reference_layer.features_slice[1]
            new_kernel = kernel[slice_start:slice_end].reshape((1, 1, slice_end - slice_start, 1))
            new_bias = bias[slice_start:slice_end]
        elif reference_layer.op == LayerType.concat:
            concat_dim = CONCAT_AXIS_TO_DIM[reference_layer.axis]
            if concat_dim == 3:
                # Concat on features: each input gets its own slice of the kernel/bias.
                # NOTE(review): assumes all concat inputs have the same number of features.
                new_kernel_len = reference_layer.output_shapes[0][concat_dim] // len(reference_layer.input_list)
                chunk_start = relative_input_idx * new_kernel_len
                chunk_end = chunk_start + new_kernel_len
                new_kernel = kernel[chunk_start:chunk_end].reshape([1, 1, new_kernel_len, 1])
                new_bias = bias[chunk_start:chunk_end]
            else:
                new_kernel = kernel.reshape(original_kernel_shape)
                new_bias = bias.reshape(original_bias_shape)
        elif reference_layer.op in [LayerType.ew_mult, LayerType.ew_div]:
            new_bias = bias.reshape(original_bias_shape)
            if reference_layer.op == LayerType.ew_mult:
                new_kernel = np.sqrt(kernel).reshape(original_kernel_shape)
            else:
                new_kernel = np.power(kernel, 2).reshape(original_kernel_shape)
        elif reference_layer.op in [LayerType.ew_add, LayerType.ew_sub]:
            new_kernel = kernel.reshape(original_kernel_shape)
            # the bias is split evenly between the inputs
            new_bias = (bias / len(reference_layer.input_list)).reshape(original_bias_shape)
            if reference_layer.op == LayerType.ew_sub and relative_input_idx != 0:
                # subtracted inputs must contribute the negated share of the bias
                new_bias *= -1
        else:
            # previously fell through to an UnboundLocalError on new_kernel
            raise NormalizationOptimizingException(
                f"Unsupported reference layer op {reference_layer.op} while updating normalization params of "
                f"{new_normalization_layer_name}",
            )

        new_normalization_params["kernel:0"] = new_kernel.astype(np.float32)
        new_normalization_params["bias:0"] = new_bias.astype(np.float32)
        new_normalization_params_full_names = {
            f"{new_normalization_layer_name}/{k}": v for k, v in new_normalization_params.items()
        }

        self._params.update(new_normalization_params_full_names)
        if new_normalization_layer_name not in self._params.layers:
            # rebuild ModelParams so that its layer index includes the newly written layer
            self._params = ModelParams(self._params.params)
