from enum import Enum

import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import PostprocessTarget, PostprocessType
from hailo_sdk_client.model_translator.exceptions import UnsupportedPostprocessLayerError
from hailo_sdk_client.model_translator.graph_lookup import BwdChainNode, FwdChainNode, get_all_nodes_in_chain
from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN
from hailo_sdk_common.hailo_nn.hn_definitions import (
    ActivationType,
    LayerType,
    PaddingType,
    ResizeBilinearPixelsMode,
    ResizeMethod,
)
from hailo_sdk_common.hailo_nn.hn_layers import (
    ActivationLayer,
    EWAddLayer,
    EWMultLayer,
    FusedConv2DLayer,
    FusedStandaloneActivationLayer,
    FusedStandaloneEWAddLayer,
    OutputLayer,
    PostprocessLayer,
    ResizeLayer,
    ShortcutLayer,
)
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.tools.models_translator_helper import is_spatial_broadcast

logger = default_logger()


class PostprocessAdditionMode(Enum):
    """
    How a logits postprocess layer is attached to its source layer.

    Read by ``FuserHelper.add_logits_as_postprocess_layer_to_hn`` to decide which
    successor of the source layer is used as the reference for the output shape.
    """

    # The reference successor is the source layer's output layer.
    ADD_AS_OUTPUT = "add_as_output"
    # The reference successor is the one whose op equals the requested activation type.
    MAPPING_FROM_NN_CORE_TO_CPU = "mapping_from_nn_core_to_cpu"


class FuserHelper:
    def __init__(self, model):
        self._model = model

    @property
    def model(self):
        return self._model

    @property
    def layers_with_activation(self):
        return self._layers_with_activation

    @model.setter
    def model(self, model):
        self._model = model

    @staticmethod
    def add_logits_as_postprocess_layer_to_hn(
        graph,
        layers,
        activation_type,
        axis,
        logits_layer_names=None,
        remove_succ_output=True,
        postprocess_addition_mode=PostprocessAdditionMode.ADD_AS_OUTPUT,
    ):
        """
        Attach a CPU logits postprocess layer, followed by a new output layer, to each given layer.

        Args:
            graph: Graph that owns ``layers``; it is modified in place.
            layers: Layers to attach a logits postprocess layer to. None of them may be an output layer.
            activation_type: Stored as the ``type`` of the postprocess layer. For ``LayerType.softmax``
                the reference output shape is kept as is; for any other value the dimension at the
                postprocess axis is set to 1.
            axis: Requested axis. Only -1 and 3 (the channels dimension) are accepted; the postprocess
                layer is always created with axis 3.
            logits_layer_names: Optional names for the new layers, indexed like ``layers``. When empty or
                omitted, a name is generated from the layer scope, the activation type value and the
                number of postprocess layers already in the graph.
            remove_succ_output: When true and the source layer is followed by an output layer, that
                output layer is removed and the source layer's entry in ``output_layers_order`` is
                replaced by the postprocess layer name.
            postprocess_addition_mode: Selects the reference successor of the source layer: its output
                layer (``ADD_AS_OUTPUT``) or the successor whose op equals ``activation_type``
                (``MAPPING_FROM_NN_CORE_TO_CPU``).

        Returns:
            Names of the created postprocess layers, in the order of ``layers``.

        Raises:
            UnsupportedPostprocessLayerError: If a given layer is an output layer, if ``axis`` is not
                -1 or 3, or if no reference successor is found for a layer.
        """

        def matches_reference(successor):
            # Which successor counts as the reference depends on the addition mode.
            if postprocess_addition_mode == PostprocessAdditionMode.ADD_AS_OUTPUT:
                return successor.op == LayerType.output_layer
            if postprocess_addition_mode == PostprocessAdditionMode.MAPPING_FROM_NN_CORE_TO_CPU:
                return successor.op == activation_type
            return False

        created_names = []
        for layer_position, layer in enumerate(layers):
            if layer.op == LayerType.output_layer:
                raise UnsupportedPostprocessLayerError(
                    f"The logits layer command can be applied only on a real model "
                    f"node. Please change the layer name from {layer.name} to "
                    f"{next(iter(graph.predecessors(layer))).name_without_scope}",
                )

            logits_layer = PostprocessLayer()
            logits_layer.type = activation_type
            if axis not in (-1, 3):
                raise UnsupportedPostprocessLayerError(
                    f"Unsupported axis value {axis}. Logits postprocess layer can "
                    f"be applied only on the channels dimension",
                )
            logits_layer.axis = 3
            logits_layer.engine = PostprocessTarget.CPU
            logits_layer.op = LayerType.postprocess
            logits_layer.postprocess_type = PostprocessType.LOGITS
            logits_layer.index = graph.get_next_index()

            # The generated name is numbered after the postprocess layers already present in the graph.
            existing_postprocess_count = sum(1 for node in graph if node.op == LayerType.postprocess)
            if logits_layer_names:
                logits_layer.name = logits_layer_names[layer_position]
            else:
                logits_layer.name = (
                    f"{layer.scope}/{logits_layer.type.value}_logits_postprocess{existing_postprocess_count + 1}"
                )
            created_names.append(logits_layer.name)

            successors = list(graph.successors(layer))
            reference_idx = next(
                (position for position, successor in enumerate(successors) if matches_reference(successor)),
                None,
            )
            if reference_idx is None:
                raise UnsupportedPostprocessLayerError(
                    f"Unable to find output layer for layer {layer.name} when trying to add {logits_layer.name}",
                )

            # softmax doesn't change the output shape, argmax flattens the axis dimension
            reference_shape = layer.output_shapes[reference_idx]
            if activation_type == LayerType.softmax:
                output_shapes = [reference_shape]
            else:
                output_shapes = [
                    [1 if dim_idx == logits_layer.axis else dim for dim_idx, dim in enumerate(reference_shape)],
                ]

            # wire the source layer into the postprocess layer
            logits_layer.output_shapes = output_shapes
            layer.outputs.append(logits_layer.name)
            graph.add_edge(layer, logits_layer)
            logits_layer.input_shapes.extend(layer.output_shapes)
            logits_layer.inputs.append(layer.name)

            # create the output layer that follows the postprocess layer
            logits_output = OutputLayer()
            logits_output.name = f"{logits_layer.name}_output_layer"
            logits_output.engine = PostprocessTarget.CPU
            logits_output.index = graph.get_next_index()
            logits_output.input_shapes = output_shapes
            logits_output.output_shapes = output_shapes

            if remove_succ_output:
                # drop the nn_core output layer so the postprocess output layer is the only one
                nn_core_output = next(
                    (successor for successor in successors if successor.op == LayerType.output_layer),
                    None,
                )
                if nn_core_output is not None:
                    layer.outputs.remove(nn_core_output.name)
                    graph.remove_edge(layer, nn_core_output)
                    graph.remove_node(nn_core_output)
                    outputs_order = graph.net_params.output_layers_order
                    outputs_order[outputs_order.index(layer.name)] = logits_layer.name

            # connect the new output layer to the postprocess layer
            graph.add_node(logits_output)
            graph.add_edge(logits_layer, logits_output)
            logits_output.inputs.append(logits_layer.name)
            logits_output.input_indices.append(logits_layer.index)
            logits_layer.outputs.append(logits_output.name)

        return created_names

    def run_broadcast_ew(self, layers=None):
        """
        Record input repeats on two-input elementwise layers whose inputs need broadcasting.

        A layer is considered when its op is one of base_ew_add, base_ew_sub, ew_sub, ew_mult, ew_div,
        or when it is a base_activation layer with the ``inv_pos`` activation. For such a layer with
        exactly two predecessors whose shapes form a (two-sided) feature or spatial broadcast,
        ``update_input_repeats`` is called once for every dimension after the first.

        Args:
            layers: Layers to inspect. When omitted or empty, every layer of the model is inspected.
        """
        broadcastable_ops = (
            LayerType.base_ew_add,
            LayerType.base_ew_sub,
            LayerType.ew_sub,
            LayerType.ew_mult,
            LayerType.ew_div,
        )
        for layer in layers if layers else list(self.model):
            is_candidate = layer.op in broadcastable_ops or (
                layer.op == LayerType.base_activation and layer.activation == ActivationType.inv_pos
            )
            if not is_candidate:
                continue

            preds = list(self.model.predecessors(layer))
            if len(preds) != 2:
                continue

            # shape of the specific output of each predecessor that feeds this layer
            pred0_shape = preds[0].output_shapes[preds[0].outputs.index(layer.name)]
            pred1_shape = preds[1].output_shapes[preds[1].outputs.index(layer.name)]

            does_apply_broadcast = FuserHelper.is_feature_broadcast(
                pred0_shape,
                pred1_shape,
                is_two_sided=True,
            ) or is_spatial_broadcast(pred0_shape, pred1_shape, is_two_sided=True)
            if not does_apply_broadcast:
                continue

            for dim in range(1, len(pred0_shape)):
                # pred0 is the repeated side when its size on this dim is 1, or (edge case for groups)
                # when it has a ``groups`` attribute and its size on dim 3 equals that value;
                # otherwise pred1 is the repeated side
                repeat_pred0 = pred0_shape[dim] == 1 or (
                    hasattr(preds[0], "groups") and dim == 3 and pred0_shape[dim] == preds[0].groups
                )
                if repeat_pred0:
                    shape_to_repeat, index_to_repeat = pred0_shape, 0
                else:
                    shape_to_repeat, index_to_repeat = pred1_shape, 1

                # the target shape is the layer's input shape on the other side
                output_shape = layer.input_shapes[1 - index_to_repeat]
                self.update_input_repeats(layer, shape_to_repeat, output_shape, dim)

    def replace_spatial_input_repeats_with_resize(self):
        """
        Replace spatial input repeats of two-input elementwise layers with explicit resize layers.

        For every base_ew_add / base_ew_sub / ew_sub / ew_mult / ew_div layer whose ``input_repeats``
        differ from all-ones and that has exactly two predecessors, each input whose repeats (all
        entries but the last) exceed 1 gets a resize layer in front of it, unless a matching resize
        can be reused. The handled entries of ``input_repeats`` are then reset to 1. Newly created
        layers are relaxed into the graph at the end.
        """
        elementwise_ops = (
            LayerType.base_ew_add,
            LayerType.base_ew_sub,
            LayerType.ew_sub,
            LayerType.ew_mult,
            LayerType.ew_div,
        )
        no_repeats = [[1, 1, 1], [1, 1, 1]]
        created_layers = []
        successors_meta_data = {}

        for layer in list(self._model):
            if layer.op not in elementwise_ops or layer.input_repeats == no_repeats:
                continue

            preds = list(self.model.predecessors(layer))
            if len(preds) != 2:
                continue

            # shape of the specific output of each predecessor that feeds this layer
            pred_shapes = [pred.output_shapes[pred.outputs.index(layer.name)] for pred in preds]
            needs_resize = [any(ratio > 1 for ratio in layer.input_repeats[side][:-1]) for side in (0, 1)]

            if all(needs_resize):
                # dual broadcast in ew operation [N,H,1,C] + [N,1,W,C] -> [N,H,W,C]
                # apply resize on pred0 and pred1
                output_shape = [max(dims) for dims in zip(pred_shapes[0], pred_shapes[1])]
                for side, (pred, pred_shape) in enumerate(zip(preds, pred_shapes)):
                    if not self.is_reused_broadcast(pred, layer, output_shape):
                        self.add_broadcast_resize(
                            pred,
                            layer,
                            created_layers,
                            successors_meta_data,
                            pred_shape,
                            output_shape,
                            is_spatial=True,
                        )
                    layer.input_repeats[side][:-1] = [1, 1]
            elif any(needs_resize):
                # regular broadcast: only one side is repeated
                side = 0 if needs_resize[0] else 1
                pred = preds[side]
                pred_shape = pred_shapes[side]
                neighbor_shape = pred_shapes[1 - side]
                # spatial dims come from the neighbor, features stay those of the resized input
                output_shape = list(neighbor_shape[:-1]) + [pred_shape[-1]]
                if not self.is_reused_broadcast(pred, layer, output_shape):
                    self.add_broadcast_resize(
                        pred,
                        layer,
                        created_layers,
                        successors_meta_data,
                        pred_shape,
                        output_shape,
                        is_spatial=True,
                    )
                layer.input_repeats[side][:-1] = [1, 1]

        for created in created_layers:
            self.model.relax_new_layer_into_graph(created, successors_meta_data)

    @staticmethod
    def is_feature_broadcast(layer_shape, neighbor_shape, is_two_sided=False):
        if layer_shape[-1] == neighbor_shape[-1]:
            return False

        if (not is_two_sided and neighbor_shape[-1] % layer_shape[-1] != 0) or (
            is_two_sided and layer_shape[-1] % neighbor_shape[-1] != 0 and neighbor_shape[-1] % layer_shape[-1] != 0
        ):
            return False

        if len(layer_shape) == 2 or len(neighbor_shape) == 2:
            return False

        if layer_shape[1] != neighbor_shape[1]:
            return False

        if layer_shape[2] != neighbor_shape[2]:
            return False

        return True

    @staticmethod
    def is_layer_in_masked_softmax_block(model, layer):
        """
        Check if the given layer is inside a softmax block.

        Only base_ew_sub, ew_mult and matmul layers are examined; any other op yields False.
        - base_ew_sub: both the layers after it and the layers before it must match the unfused
          softmax sequence around its position.
        - matmul: the layers after it must match the fused softmax sequence.
        - ew_mult: the layers before it must match the unfused sequence without its last entry.

        Returns:
            False for unsupported ops; otherwise the result of ``get_all_nodes_in_chain`` (for
            base_ew_sub, the forward result ``and`` the backward result). Callers should treat
            the value as truthy/falsy.
        """
        if layer.op not in (LayerType.base_ew_sub, LayerType.ew_mult, LayerType.matmul):
            return False

        unfused_sequence = [
            LayerType.normalization,
            LayerType.reduce_max,
            LayerType.base_ew_sub,
            LayerType.base_activation,
            LayerType.ew_mult,
            LayerType.reduce_sum,
            LayerType.base_activation,
            LayerType.ew_mult,
        ]
        fused_sequence = [
            LayerType.reduce_max,
            LayerType.ew_sub,
            LayerType.ew_mult,
            LayerType.reduce_sum,
            LayerType.ew_mult,
            LayerType.matmul,
        ]

        def find_chain(ops, node_cls):
            # exact-match lookup of ``ops`` starting at ``layer``, in the direction given by node_cls
            return get_all_nodes_in_chain(model, layer, [node_cls(op) for op in ops], exact_match=True)

        if layer.op == LayerType.matmul:
            return find_chain(fused_sequence, FwdChainNode)

        if layer.op == LayerType.base_ew_sub:
            position = unfused_sequence.index(layer.op)
            fwd_chain = find_chain(unfused_sequence[position + 1 :], FwdChainNode)
            bwd_chain = find_chain(unfused_sequence[:position][::-1], BwdChainNode)
            return fwd_chain and bwd_chain

        # ew_mult is at the end of the softmax block no need to check fwd
        return find_chain(unfused_sequence[:-1][::-1], BwdChainNode)

    def is_reused_broadcast(self, pred, layer, output_shape):
        """
        Reuse an existing resize successor of ``pred`` for ``layer`` when one fits.

        A successor fits when it is a resize layer whose first output shape equals ``output_shape``.
        When ``pred`` is a feature splitter, the resize must additionally read the same split index
        as ``layer``; in that case the split-index entry of ``layer`` is removed from ``pred``.
        On a match, ``layer`` is rewired to consume the resize instead of ``pred``.

        Returns:
            True when an existing resize was reused (the graph was rewired), False otherwise.
        """
        for candidate in list(self.model.successors(pred)):
            if candidate.op != LayerType.resize or candidate.output_shapes[0] != output_shape:
                continue

            if pred.op == LayerType.feature_splitter:
                layer_slot = pred.output_indices.index(layer.index)
                candidate_slot = pred.output_indices.index(candidate.index)
                if pred.split_indices[layer_slot] != pred.split_indices[candidate_slot]:
                    # different split indices, should have different resize layers
                    continue
                pred.split_indices.pop(layer_slot)

            # route pred -> resize -> layer instead of pred -> layer
            self.replace_pred(layer, pred, candidate)
            self.add_succs(candidate, [layer])
            self.remove_succ(pred, layer)
            return True

        return False

    def add_broadcast_resize(
        self,
        pred,
        layer,
        new_layers,
        successors_meta_data,
        input_shape,
        output_shape,
        is_spatial,
    ):
        """
        Insert a nearest-neighbor resize layer between ``pred`` and ``layer``.

        Args:
            pred: Layer whose output has to be broadcast.
            layer: Elementwise layer consuming ``pred``.
            new_layers: List the created resize layer is appended to.
            successors_meta_data: Mapping updated through ``HailoNN.update_successors_meta_data``.
            input_shape: Shape entering the resize layer.
            output_shape: Shape leaving the resize layer; it also becomes the input shape of
                ``layer`` for this input.
            is_spatial: Only affects the generated name ("spatial" vs "features").
        """
        broadcast_resize = ResizeLayer()
        broadcast_resize.index = self.model.get_next_index()

        scope_prefix = f"{layer.scope}/" if layer.scope else ""
        kind = "spatial" if is_spatial else "features"
        broadcast_resize.name = (
            f"{scope_prefix}resize_{kind}_{pred.name_without_scope}_{layer.name_without_scope}"
        )
        broadcast_resize.original_names = layer.original_names
        broadcast_resize.input_shape = input_shape
        broadcast_resize.output_shapes = [output_shape]
        broadcast_resize.resize_method = ResizeMethod.nearest_neighbor
        broadcast_resize.resize_bilinear_pixels_mode = ResizeBilinearPixelsMode.disabled

        # height/width ratios are only computed for shapes with more than two dimensions
        if len(output_shape) > 2:
            broadcast_resize.h_ratios = [float(output_shape[1] / input_shape[1])]
            broadcast_resize.w_ratios = [float(output_shape[2] / input_shape[2])]
        else:
            broadcast_resize.h_ratios = [1.0]
            broadcast_resize.w_ratios = [1.0]
        broadcast_resize.f_ratios = [float(output_shape[-1] / input_shape[-1])]
        broadcast_resize.append_input_index(pred.index)
        broadcast_resize.append_input_layer(pred.name)
        broadcast_resize.block_info = layer.block_info

        # graph edges: pred -> layer becomes pred -> resize -> layer
        graph = self.model
        graph.add_node(broadcast_resize)
        new_layers.append(broadcast_resize)
        graph.remove_edge(pred, layer)
        graph.add_edge(pred, broadcast_resize)
        graph.add_edge(broadcast_resize, layer)

        # per-layer bookkeeping mirroring the new edges
        pred.replace_output_layer(layer.name, broadcast_resize.name)
        pred.replace_output_index(layer.index, broadcast_resize.index)
        layer.replace_input_shape(pred.name, output_shape)
        layer.replace_input_index(pred.index, broadcast_resize.index)
        layer.replace_input_layer(pred.name, broadcast_resize.name)
        broadcast_resize.append_output_layer(layer.name)
        broadcast_resize.append_output_index(layer.index)
        HailoNN.update_successors_meta_data(layer, successors_meta_data)

        logger.debug(
            f"Inserted a resize layer to enable broadcast elementwise action in layer {layer.name_without_scope}.",
        )

    def split_ew_add_n_layers(self, split_to_base_layers=True, act_dict=None):
        """
        Replace every ``base_ew_add_n`` layer in the model with a cascade of elementwise-add layers.

        The predecessors of the original layer are grouped into new ew_add layers level by level
        until a single ew_add remains; that last ew_add takes over the original layer's names,
        successors and position in ``output_layers_order``.

        Args:
            split_to_base_layers: When true the new layers are ``EWAddLayer`` instances, otherwise
                ``FusedStandaloneEWAddLayer`` instances.
            act_dict: Optional mapping keyed by original layer name. Entry ``[0]`` is assigned as the
                activation of the last ew_add, entry ``[1]`` is returned keyed by that ew_add's name.

        Returns:
            Dict mapping the name of each last ew_add to ``act_dict[<original layer name>][1]``;
            empty when ``act_dict`` is None or has no matching entries.
        """
        new_layers = []
        layers_to_remove = []
        successors_meta_data = {}
        ew_add_to_act_params = {}

        for layer in list(self.model):
            if layer.op == LayerType.base_ew_add_n:
                scope = f"{layer.scope}/" if layer.scope else ""
                preds = list(self.model.predecessors(layer))
                succs = list(self.model.successors(layer))
                # all ew_adds created for this layer get indices base_idx, base_idx + 1, ...
                base_idx = self.model.get_next_index()
                new_ew_adds = []
                curr_preds = preds[:]

                # input capacity of a single ew_add of the chosen flavor
                if split_to_base_layers:
                    MAX_EW_ADD_INPUTS = EWAddLayer().number_of_inputs_supported
                else:
                    MAX_EW_ADD_INPUTS = FusedStandaloneEWAddLayer().number_of_inputs_supported

                # each pass builds one level; the ew_adds of a level are the preds of the next one
                while len(curr_preds) > 1:
                    i = 1
                    n = len(curr_preds)
                    curr_ew_adds = []
                    while i < n + (MAX_EW_ADD_INPUTS - 1):
                        curr_ew_add = EWAddLayer() if split_to_base_layers else FusedStandaloneEWAddLayer()

                        index = len(new_ew_adds) + len(curr_ew_adds)
                        curr_ew_add.index = base_idx + (index)
                        curr_ew_add.name = f"{scope}ew_add{index}_{layer.name_without_scope}"

                        # use actual preds unless we reached the remainder ew_add
                        # (the number of preds does not divide evenly into groups)
                        if i < n:
                            inputs = curr_preds[i - (MAX_EW_ADD_INPUTS - 1) : i + 1]
                        else:
                            # NOTE(review): `reminder` appears to be the number of trailing preds
                            # left over after the full groups - confirm for capacities other than 2
                            reminder = MAX_EW_ADD_INPUTS - (i - n + 1)
                            inputs = curr_ew_adds[-1:] if reminder < 2 else []
                            inputs += curr_preds[n - reminder : n]
                            # the last ew_add of this level is moved out of the level list,
                            # so it is not offered again as a pred of the next level
                            new_ew_adds.append(curr_ew_adds.pop(-1))
                        curr_ew_add.inputs = [x.name for x in inputs]
                        curr_ew_add.input_indices = [x.index for x in inputs]
                        curr_ew_add.input_shapes = [x.output_shape for x in inputs]
                        curr_ew_add.output_shapes = layer.output_shapes
                        curr_ew_add.block_info = layer.block_info

                        self.model.add_node(curr_ew_add)
                        curr_ew_adds.append(curr_ew_add)

                        # redirect each input from the original layer to the new ew_add
                        for inp in inputs:
                            self.model.add_edge(inp, curr_ew_add)
                            inp.replace_output_layer(layer.name, curr_ew_add.name)
                            inp.replace_output_index(layer.index, curr_ew_add.index)

                        i += MAX_EW_ADD_INPUTS

                    new_ew_adds += curr_ew_adds
                    curr_preds = curr_ew_adds

                # detach the original predecessors from the layer being replaced
                for pred in preds:
                    self.model.remove_edge(pred, layer)

                # original names and successor data are assigned to the last ew_add
                last_ew_add = new_ew_adds[-1]
                last_ew_add.original_names = layer.original_names
                if act_dict is not None and layer.name in act_dict:
                    last_ew_add.activation = act_dict[layer.name][0]
                    ew_add_to_act_params[last_ew_add.name] = act_dict[layer.name][1]

                # successors of the original layer now consume the last ew_add
                for succ in succs:
                    self.model.remove_edge(layer, succ)
                    self.model.add_edge(last_ew_add, succ)
                    succ.replace_input_layer(layer.name, last_ew_add.name)
                    succ.replace_input_index(layer.index, last_ew_add.index)
                    HailoNN.update_successors_meta_data(succ, successors_meta_data)

                new_layers.extend(new_ew_adds)
                layers_to_remove.append(layer)

                # keep the model's output order pointing at the replacement layer
                for i, output in enumerate(self.model.net_params.output_layers_order):
                    if layer.name == output:
                        self.model.net_params.output_layers_order[i] = last_ew_add.name

        # finalize graph manipulation outside toposort iteration
        if layers_to_remove:
            for layer in layers_to_remove:
                self.model.remove_layer(layer)

        for layer in new_layers:
            self.model.relax_new_layer_into_graph(layer, successors_meta_data)

        return ew_add_to_act_params

    def handle_ew_div(self, split_to_base_layers=True):
        """
        Split every ew_div layer into an ``inv_pos`` activation followed by an ew_mult layer.

        The predecessor at ``input_indices[1]`` of the ew_div is routed through the new activation
        layer; the other predecessor feeds the ew_mult directly. The ew_mult takes over the
        successors of the removed ew_div.

        Args:
            split_to_base_layers: When true the activation is an ``ActivationLayer``, otherwise a
                ``FusedStandaloneActivationLayer``.

        Returns:
            List of the created layers, as ``[ew_mult, inv_layer]`` pairs in creation order.
        """
        layers_to_remove = []
        new_layers = []
        successors_meta_data = {}

        for layer in list(self._model):
            if layer.op == LayerType.ew_div:
                scope = f"{layer.scope}/" if layer.scope else ""
                # the second input is the one that gets inverted
                inv_pred_index = layer.input_indices[1]
                preds = list(self._model.predecessors(layer))

                # graph predecessor order is not assumed to match input order, so match by index
                if preds[0].index == inv_pred_index:
                    inv_pred, other_pred = preds
                else:
                    other_pred, inv_pred = preds

                # the ew_mult starts as a copy of the ew_div's connectivity
                # NOTE(review): the lists are assigned directly (not copied); whether they end up
                # shared with the ew_div depends on the layer attribute setters - confirm
                ew_mult = EWMultLayer()
                ew_mult.index = self._model.get_next_index()
                ew_mult.name = f"{scope}ew_mult_{layer.name_without_scope}"
                ew_mult.inputs = layer.inputs
                ew_mult.input_indices = layer.input_indices
                ew_mult.outputs = layer.outputs
                ew_mult.output_indices = layer.output_indices
                ew_mult.input_shapes = layer.input_shapes
                ew_mult.output_shapes = layer.output_shapes
                ew_mult.move_params(layer)

                # the inv activation sits between inv_pred and ew_mult and keeps the input shape
                inv_layer = ActivationLayer() if split_to_base_layers else FusedStandaloneActivationLayer()
                inv_layer.index = ew_mult.index + 1
                inv_layer.name = f"{scope}activation_{layer.name_without_scope}"
                inv_layer.inputs = [inv_pred.name]
                inv_layer.input_indices = [inv_pred.index]
                inv_layer.input_shape = layer.input_shapes[1]
                inv_layer.output_shapes = [inv_layer.input_shape]
                inv_layer.activation = ActivationType.inv_pos
                inv_layer.outputs = [ew_mult.name]
                inv_layer.output_indices = [ew_mult.index]
                inv_layer.move_params(layer)

                self._model.add_node(inv_layer)
                self._model.add_node(ew_mult)

                # ew_mult reads the inverted value instead of inv_pred
                self._model.add_edge(inv_layer, ew_mult)
                ew_mult.replace_input_layer(inv_pred.name, inv_layer.name)
                ew_mult.replace_input_index(inv_pred.index, inv_layer.index)

                # other_pred -> ew_div becomes other_pred -> ew_mult
                self._model.remove_edge(other_pred, layer)
                self._model.add_edge(other_pred, ew_mult)
                other_pred.replace_output_layer(layer.name, ew_mult.name)
                other_pred.replace_output_index(layer.index, ew_mult.index)

                # inv_pred -> ew_div becomes inv_pred -> inv_layer
                self._model.remove_edge(inv_pred, layer)
                self._model.add_edge(inv_pred, inv_layer)
                inv_pred.replace_output_layer(layer.name, inv_layer.name)
                inv_pred.replace_output_index(layer.index, inv_layer.index)

                HailoNN.update_successors_meta_data(layer, successors_meta_data)

                # successors of the ew_div now consume the ew_mult
                for succ in list(self._model.successors(layer)):
                    self._model.remove_edge(layer, succ)
                    self._model.add_edge(ew_mult, succ)
                    succ.replace_input_index(layer.index, ew_mult.index)
                    succ.replace_input_layer(layer.name, ew_mult.name)
                    HailoNN.update_successors_meta_data(succ, successors_meta_data)

                layers_to_remove.append(layer)
                new_layers.extend([ew_mult, inv_layer])

                logger.debug(f"Replaced EW Div layer {layer.name} with inv activation and ew mult layers")

        # finalize graph manipulation outside toposort iteration
        for layer in layers_to_remove:
            self._model.remove_layer(layer)

        for layer in new_layers:
            self._model.relax_new_layer_into_graph(layer, successors_meta_data)

        return new_layers

    @staticmethod
    def replace_dense_with_conv1x1(dense_layer, model, params=None):
        """
        Replace a dense layer with an equivalent 1x1 fused convolution layer.

        Input shapes are flattened to ``[N, 1, 1, prod(rest)]`` and output shapes become
        ``[N, 1, 1, features]``. The kernel is reshaped to ``[1, 1, in, out]``.

        Args:
            dense_layer: Layer to replace; it is removed from ``model``.
            model: Graph that owns ``dense_layer``; it is modified in place.
            params: Optional object exposing a ``params`` mapping and an ``update`` method. When
                given (and truthy), the kernel is read from it under ``"<layer name>/kernel:0"``
                and the reshaped kernel is written back under the same key; otherwise the kernel
                of ``dense_layer`` is used.

        Returns:
            The new convolution layer.
        """
        conv_input_shapes = [[shape[0], 1, 1, int(np.prod(shape[1:]))] for shape in dense_layer.input_shapes]
        conv_output_shapes = [[shape[0], 1, 1, shape[-1]] for shape in dense_layer.output_shapes]

        kernel_key = f"{dense_layer.name}/kernel:0"
        if params:
            kernel = params.params[kernel_key]
        else:
            kernel = dense_layer.kernel
        conv_kernel = np.reshape(kernel, [1, 1, kernel.shape[0], kernel.shape[-1]])
        if params:
            # update kernel in params dict
            params.update({kernel_key: conv_kernel})

        conv = FusedConv2DLayer()
        conv.original_names = dense_layer.original_names
        conv.index = model.get_next_index()
        conv.name = f"conv{conv.index}"

        # convolution parameters: unit strides/dilations, valid padding, single group
        conv.kernel = conv_kernel
        conv.kernel_shape = conv_kernel.shape
        conv.bias = dense_layer.bias.copy()
        conv.strides = [1, 1, 1, 1]
        conv.dilations = [1, 1, 1, 1]
        conv.padding = PaddingType.valid
        conv.groups = 1
        conv.activation = dense_layer.activation

        # connectivity is copied from the dense layer, shapes are the flattened ones
        conv.inputs = dense_layer.inputs.copy()
        conv.input_shapes = conv_input_shapes
        conv.input_indices = dense_layer.input_indices.copy()

        conv.outputs = dense_layer.outputs.copy()
        conv.output_shapes = conv_output_shapes
        conv.output_indices = dense_layer.output_indices.copy()

        model.add_node(conv)

        dense_preds = list(model.predecessors(dense_layer))
        dense_succs = list(model.successors(dense_layer))

        for pred in dense_preds:
            model.add_edge(pred, conv)
            pred.replace_output_index(dense_layer.index, conv.index)
            pred.replace_output_shape(dense_layer.name, conv.input_shape)
            pred.replace_output_layer(dense_layer.name, conv.name)

            # an input layer has to expose the flattened shape on its input side as well
            if pred.op == LayerType.input_layer:
                pred.input_shapes = pred.output_shapes

        for succ in dense_succs:
            model.add_edge(conv, succ)
            succ.replace_input_index(dense_layer.index, conv.index)
            succ.replace_input_layer(dense_layer.name, conv.name)

        model.remove_layer(dense_layer)
        return conv

    def swap_layers_order(self, layer1, layer2, is_layer1_first=True):
        """
        Swap the order of two directly connected layers, layer1 and layer2.

        is_layer1_first is True: Layer1 -> Layer2 becomes Layer2 -> Layer1.
        is_layer1_first is False: Layer2 -> Layer1 becomes Layer1 -> Layer2.

        Args:
            layer1: One of the two layers to swap.
            layer2: The other layer to swap.
            is_layer1_first: Whether layer1 currently precedes layer2.
        """
        # "first"/"second" describe the order BEFORE the swap: first_layer feeds second_layer
        first_layer = layer1
        second_layer = layer2
        # default case: the layer moved to the end takes the shapes of second_layer's output,
        # the layer moved to the front reads what first_layer used to produce
        first_layer_new_input_shapes = second_layer.output_shapes.copy()
        first_layer_new_output_shapes = second_layer.output_shapes.copy()
        second_layer_new_input_shapes = first_layer.output_shapes.copy()
        second_layer_new_output_shapes = second_layer.output_shapes.copy()
        if not is_layer1_first:
            first_layer = layer2
            second_layer = layer1
            # here the shapes are derived from the output of first_layer's first predecessor
            first_layer_new_input_shapes = next(iter(self._model.predecessors(first_layer))).output_shapes.copy()
            first_layer_new_output_shapes = first_layer.output_shapes.copy()
            # note: this is the same list object as first_layer_new_input_shapes (no copy)
            second_layer_new_input_shapes = first_layer_new_input_shapes
            second_layer_new_output_shapes = second_layer_new_input_shapes.copy()

        # disconnect first_layer -> second_layer
        self.remove_pred(second_layer, first_layer)
        self.remove_succ(first_layer, second_layer)
        # predecessors of first_layer now feed second_layer
        first_layer_preds = list(self._model.predecessors(first_layer))
        for pred in first_layer_preds:
            self.add_succs(pred, [second_layer])
            self.remove_pred(first_layer, pred)
            self.remove_succ(pred, first_layer)
        self.add_preds(second_layer, first_layer_preds)
        first_layer.input_shapes = first_layer_new_input_shapes
        first_layer.output_shapes = first_layer_new_output_shapes
        # successors of second_layer now read from first_layer
        second_layer_succs = list(self._model.successors(second_layer))
        for succ in second_layer_succs:
            self.replace_pred(succ, second_layer, first_layer)
            self.remove_succ(second_layer, succ)
        second_layer.input_shapes = second_layer_new_input_shapes
        second_layer.output_shapes = second_layer_new_output_shapes
        self.add_succs(first_layer, second_layer_succs)
        # reconnect in the swapped order: second_layer -> first_layer
        # (first_layer's input shapes were already set above, so they are not updated again)
        self.add_preds(first_layer, [second_layer], update_input_shapes=False)
        self.add_succs(second_layer, [first_layer])

        # If the original second layer is listed in the output layers order, first_layer (now last) replaces it
        if second_layer.name in self._model.net_params.output_layers_order:
            second_layer_index = self._model.net_params.output_layers_order.index(second_layer.name)
            self._model.net_params.output_layers_order[second_layer_index] = first_layer.name

    def create_layer(self, layer_class, index, name, from_layer, new_layers, output_shapes, block_info=None):
        new_layer = layer_class()
        new_layer.index = index
        scope = f"{from_layer.scope}/" if from_layer.scope else ""
        new_layer.name = f"{scope}{from_layer.name_without_scope}_{name}"
        new_layer.output_shapes = output_shapes
        new_layer.move_params(from_layer)
        self.model.add_node(new_layer)
        new_layers.append(new_layer)
        if block_info is not None:
            new_layer.block_info = block_info

        return new_layer

    def replace_layer_with_shortcut(self, layer, new_layers=None, in_out_shortcut=False):
        """
        This method replaces the given layer with a shortcut layer.

        One shortcut layer is created per handled successor of ``layer``.

        Args:
            layer: Layer to replace (regular mode) or the input layer to insert shortcuts after
                (``in_out_shortcut`` mode).
            new_layers: List the created shortcut layers are appended to. It is only honored in
                ``in_out_shortcut`` mode; when omitted there, an internal list is used. In regular
                mode a fresh internal list is always used and a caller-supplied list is left
                untouched.
            in_out_shortcut: When true, shortcuts are inserted between ``layer`` and each of its
                output-layer successors (``layer`` itself acts as the predecessor). When false,
                shortcuts are wired between the predecessors and all successors of ``layer``.
        """
        if in_out_shortcut:
            # a missing list used to crash inside create_layer (None.append); fall back to a local one
            if new_layers is None:
                new_layers = []
            # related to _add_shortcut_to_empty_model functionality
            # insert shortcut layer between input and output layers
            preds = [layer]
            succs = [succ for succ in self.model.successors(layer) if succ.op == LayerType.output_layer]
            # detach in from out
            for succ in succs:
                layer.outputs.remove(succ.name)
                succ.inputs.remove(layer.name)
        else:
            new_layers = []
            succs = list(self.model.successors(layer))
            preds = list(self.model.predecessors(layer))

        for succ in succs:
            shortcut = self.create_layer(
                ShortcutLayer,
                self.model.get_next_index(),
                "shortcut",
                layer,
                new_layers,
                [layer.output_shape],
            )
            # wire preds -> shortcut -> succ and drop the references to the replaced layer
            self.add_preds(shortcut, preds)
            self.add_succs(shortcut, [succ])
            self.replace_pred(succ, layer, shortcut)
            for pred in preds:
                self.replace_succ(pred, layer, shortcut)

            logger.debug(f"Replaced {layer.full_name_msg} with shortcut layer")

    def remove_layer(self, layer, layers_to_remove, fuse_to_succ=False, fuse_to_pred=True):
        """
        Bypass ``layer`` by connecting its predecessor directly to its successors.

        If bypassing is not possible (see ``needs_shortcut`` below), the layer is replaced with
        shortcut layers instead and no original names are fused. In both cases the layer is only
        appended to ``layers_to_remove``; it is not deleted from the model here.

        Args:
            layer: The layer to bypass. Its first predecessor is the only one considered.
            layers_to_remove: List the layer is appended to.
            fuse_to_succ: When True, the layer's original names are added to every successor
                (with ``reverse_insertion=True``).
            fuse_to_pred: When True, the layer's original names are added to the predecessor.
        """

        def feeds_output_layer(node):
            return any(node_succ.op == LayerType.output_layer for node_succ in self.model.successors(node))

        def needs_shortcut(source, target):
            # replace with shortcut instead of removing, if:
            # 1. the predecessor is already connected to the successor
            # 2. both the predecessor and the layer are connected to output layers
            if target.name in source.outputs:
                return True
            return feeds_output_layer(source) and feeds_output_layer(layer)

        # Assuming null layer have single predecessor
        source = next(iter(self.model.predecessors(layer)))
        targets = list(self.model.successors(layer))
        if any(needs_shortcut(source, target) for target in targets):
            self.replace_layer_with_shortcut(layer)
            layers_to_remove.append(layer)
            return

        self.model.remove_edge(source, layer)

        for target in targets:
            if layer.name not in source.outputs:
                # The layer's output slot was already handed to an earlier successor.
                source.append_output_index(target.index)
                source.append_output_layer(target.name)
                source.append_output_shape(source.output_shape)
            else:
                source.replace_output_index(layer.index, target.index)
                source.replace_output_layer(layer.name, target.name)

            target.replace_input_shape(layer.name, layer.input_shapes[0])
            target.replace_input_index(layer.index, source.index)
            target.replace_input_layer(layer.name, source.name)
            if fuse_to_succ:
                for original_name in layer.original_names:
                    target.add_original_name(original_name, reverse_insertion=True)

            self.model.remove_edge(layer, target)
            self.model.add_edge(source, target)

        if fuse_to_pred:
            for original_name in layer.original_names:
                source.add_original_name(original_name)
        layers_to_remove.append(layer)

    def add_preds(self, layer, preds, update_input_shapes=True):
        """
        Register every layer in ``preds`` as an input of ``layer``.

        For each predecessor an edge is added to the model and its name and index are appended
        to ``layer``'s inputs. When ``update_input_shapes`` is True, the predecessor's
        ``output_shape`` is appended as an input shape as well.
        """
        for source in preds:
            self.model.add_edge(source, layer)
            layer.append_input_layer(source.name)
            layer.append_input_index(source.index)
            if not update_input_shapes:
                continue
            layer.append_input_shape(source.output_shape)

    def replace_pred(self, layer, old_pred, new_pred):
        """
        Swap ``old_pred`` for ``new_pred`` as an input of ``layer``.

        The model edge is moved, and the input name, input index and input shape of ``layer``
        are updated to those of ``new_pred``.
        """
        # Move the graph edge first.
        graph = self.model
        graph.remove_edge(old_pred, layer)
        graph.add_edge(new_pred, layer)

        # Then update the layer's own bookkeeping. The shape is looked up under the new name,
        # so the name has to be replaced before the shape.
        new_name, new_index = new_pred.name, new_pred.index
        layer.replace_input_layer(old_pred.name, new_name)
        layer.replace_input_index(old_pred.index, new_index)
        layer.replace_input_shape(new_name, new_pred.output_shape)

    def remove_pred(self, layer, pred, update_input_shapes=True):
        """
        Detach ``pred`` from the inputs of ``layer``.

        The model edge is removed only if it exists. The predecessor's name and index are
        dropped from ``layer``; when ``update_input_shapes`` is True, the input shape at the
        position the name occupied is dropped too.
        """
        edge = (pred, layer)
        if edge in self.model.edges:
            self.model.remove_edge(*edge)

        # Remember the slot before the name is deleted, so the matching shape can be dropped.
        position = layer.inputs.index(pred.name)
        del layer.inputs[position]
        layer.input_indices.remove(pred.index)
        if update_input_shapes:
            layer.input_shapes.pop(position)

    def add_succs(self, layer, succs, update_output_shapes=True):
        """
        Register every layer in ``succs`` as an output of ``layer``.

        Only ``layer``'s own output names and indices are extended; no model edge is added
        here. When ``update_output_shapes`` is True, ``layer.output_shapes`` is rebuilt as one
        copy of its first output shape per entry in ``layer.outputs``.
        """
        for target in succs:
            layer.append_output_layer(target.name)
            layer.append_output_index(target.index)

        if not update_output_shapes:
            return
        layer.output_shapes = len(layer.outputs) * [layer.output_shapes[0]]

    def replace_succ(self, layer, old_succ, new_succ):
        """
        Swap ``old_succ`` for ``new_succ`` as an output of ``layer``.

        Only ``layer``'s output name, index and shape are updated; model edges are untouched.
        The first input shape of ``new_succ`` is used as the new output shape.
        """
        new_name = new_succ.name
        layer.replace_output_layer(old_succ.name, new_name)
        layer.replace_output_index(old_succ.index, new_succ.index)
        # The shape is looked up under the new name, so it must follow the rename above.
        first_input_shape = new_succ.input_shapes[0]
        layer.replace_output_shape(new_name, first_input_shape)

    def remove_succ(self, layer, succ):
        """
        Detach ``succ`` from the outputs of ``layer``.

        The model edge is removed only if it exists. The successor's name and index are dropped
        from ``layer``, together with the output shape at the position the name occupied.
        """
        edge = (layer, succ)
        if edge in self.model.edges:
            self.model.remove_edge(*edge)

        # Remember the slot before the name is deleted, so the matching shape can be dropped.
        position = layer.outputs.index(succ.name)
        del layer.outputs[position]
        layer.output_indices.remove(succ.index)
        layer.output_shapes.pop(position)

    def add_succs_and_preds(self, layer, succs, preds, update_input_shapes=True, update_output_shapes=True):
        """Attach ``preds`` as inputs and then ``succs`` as outputs of ``layer``."""
        # Predecessors are wired before successors.
        self.add_preds(layer, preds, update_input_shapes=update_input_shapes)
        self.add_succs(layer, succs, update_output_shapes=update_output_shapes)

    def modify_preds_order(self, layer):
        """Reverse the order of the input layers of the given layer"""
        # Names, indices and shapes are reversed together so they stay aligned.
        # Each attribute is rebound to a reversed copy; the original sequence is not mutated.
        for attr in ("inputs", "input_indices", "input_shapes"):
            setattr(layer, attr, getattr(layer, attr)[::-1])

    @staticmethod
    def update_input_repeats(layer, reduced_shape, shape, axis):
        layer.input_repeats[layer.input_shapes.index(reduced_shape)][axis - 1] = shape[axis] // reduced_shape[axis]

    def handle_new_preds_succs(
        self,
        layer_to_inputs,
        layer_to_outputs,
        update_inputs_shapes=True,
        update_outputs_shapes=True,
    ):
        """
        Apply a batch of new input and output connections.

        Args:
            layer_to_inputs: Mapping of layer -> predecessors to attach to it.
            layer_to_outputs: Mapping of layer -> successors to attach to it.
            update_inputs_shapes: Forwarded to ``add_preds``.
            update_outputs_shapes: Forwarded to ``add_succs``.
        """
        # All predecessor links are applied before any successor link.
        for target_layer in layer_to_inputs:
            self.add_preds(target_layer, layer_to_inputs[target_layer], update_input_shapes=update_inputs_shapes)
        for source_layer in layer_to_outputs:
            self.add_succs(source_layer, layer_to_outputs[source_layer], update_output_shapes=update_outputs_shapes)
