import itertools
from collections import OrderedDict

import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerEquivType, LayerHandlerType
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import LayerEquivError


class EquivLayer:
    """A single layer's membership in a layer-equivalence set.

    Holds the layer, the source layer its tracked channels originate from, and the channel correspondence
    between the two: ``source_indices[k]`` of ``source_layer`` corresponds to ``layer_indices[k]`` of ``layer``.
    """

    def __init__(self, source_layer, layer, source_indices, layer_indices, input_index=0, type_of_layer=None):
        """
        source_layer - is the source predecessor of current layer of equiv_layer. (layer_nn)
        layer - the current layer. (layer_nn)
        source_indices - the indices of the source layer (the equiv_layer is connected to) that correspond
                        to the ones in the current layer.
        layer_indices - the indices of the current layer, element-wise matching ``source_indices``.
        input_index - the input_index (relevant for consumers)
        type_of_layer - is a producer or a consumer (Enum: LayerEquivType), or None when it is neither.
        """
        self.source_layer = source_layer
        self.layer = layer
        self.source_indices = source_indices
        self.layer_indices = layer_indices
        self.type_of_layer = type_of_layer
        self.input_index = input_index
        # "Universal" set indices matching source_indices; filled through update_set_indices.
        self.set_indices = list()
        self._equiv_name = None
        self._following_consumers = list()
        self._prev_producers = list()

    def get_as_tuple(self):
        """Return ``(source_layer, layer, source_indices, layer_indices)``."""
        return self.source_layer, self.layer, self.source_indices, self.layer_indices

    def _require_type(self):
        """Raise LayerEquivError if this equiv layer was created without a producer/consumer type."""
        if self.type_of_layer is None:
            raise LayerEquivError(
                f"Layer equiv with {self.layer.full_name} is of type None but need to be either producer or consumer.",
            )

    def is_producer(self):
        """Return True if this equiv layer is a producer.

        Raises:
            LayerEquivError: if ``type_of_layer`` is None.
        """
        self._require_type()
        return self.type_of_layer == LayerEquivType.producer

    def is_consumer(self):
        """Return True if this equiv layer is a consumer.

        Raises:
            LayerEquivError: if ``type_of_layer`` is None.
        """
        self._require_type()
        return self.type_of_layer == LayerEquivType.consumer

    def update_set_indices(self, set_indices):
        """Store the set indices that correspond to ``source_indices``."""
        self.set_indices = set_indices

    def set_name(self, equiv_layer_name):
        """Set the name of this equiv layer inside its equivalence set."""
        self._equiv_name = equiv_layer_name

    def set_consumers(self, following_consumers):
        """Record the consumers that follow this producer; an empty sequence leaves the current value unchanged.

        Raises:
            LayerEquivError: if this equiv layer is not a producer (or has no type at all).
        """
        # is_producer must be *called*: a bound method object is always truthy, so `not self.is_producer`
        # never triggered the error.
        if not self.is_producer():
            raise LayerEquivError("only producers can have set_consumers")
        if len(following_consumers) > 0:
            self._following_consumers = following_consumers

    def add_producer(self, producer):
        """Append ``producer`` to the producers preceding this consumer.

        Raises:
            LayerEquivError: if this equiv layer is not a consumer (or has no type at all).
        """
        if not self.is_consumer():
            raise LayerEquivError("only consumers can have add_producers")
        self._prev_producers.append(producer)

    @property
    def equiv_name(self):
        """Name assigned via set_name (None until set)."""
        return self._equiv_name

    @property
    def layer_name(self):
        """Full name of the current layer."""
        return self.layer.full_name

    def __eq__(self, other):
        """Equal when source layer, layer and both index sequences match (type and input_index are ignored)."""
        if not isinstance(other, EquivLayer):
            return NotImplemented
        return (
            (self.source_layer == other.source_layer)
            and (self.layer == other.layer)
            and (self.source_indices == other.source_indices)
            and (self.layer_indices == other.layer_indices)
        )

    def __hash__(self):
        # Uses a subset of the fields compared by __eq__, so equal objects hash equally.
        return hash((self.layer, self.layer_indices))

    @property
    def following_consumers(self):
        """Consumers recorded by set_consumers."""
        return self._following_consumers

    @property
    def prev_producers(self):
        """Producers recorded by add_producer."""
        return self._prev_producers

    def __repr__(self):
        # type_of_layer is None for equiv layers that are neither producers nor consumers (e.g. concat, outputs).
        type_name = None if self.type_of_layer is None else self.type_of_layer.name
        return f"EquivLayer({self.layer.full_name} type : {type_name} n_index: {len(self.layer_indices)})"


class LayersEquivSet:
    def __init__(self, hailo_model, layer_name, equiv_set_algo):
        """Build the equivalence set rooted at source layer ``layer_name``.

        Args:
            hailo_model: the model; its ``layers`` mapping and ``flow`` graph are traversed.
            layer_name: name of the root source layer.
            equiv_set_algo: algorithm key passed to each layer's ``get_algo_callback`` to classify it.

        The set is grown breadth-first from the root. Further sources are discovered at multi-source and matmul
        layers. Producers and consumers are then named. Only when no unsupported or skip layer was reached are the
        set indices and equalization info computed; otherwise the attributes derived in that last step stay None.
        """
        self._hailo_model = hailo_model
        self._hailo_model_flow = hailo_model.flow
        self._all_edges = set()
        self._consumers = list()
        self._cc_aggregators = list()
        self._transparents = list()
        self._outputs = list()
        self._unsupported = list()
        self._featurewise = list()
        self._skip = list()
        self._ew_bouncers = list()
        self._multi_source = list()
        self._activations = list()
        self._matmul = list()
        self._handled_multi_source = OrderedDict()
        self._matched_indices = dict()
        self._equiv_set_algo = equiv_set_algo
        # Derived attributes are filled only for fully supported sets (end of __init__). Defaulting them lets the
        # matching properties return None instead of raising AttributeError for unsupported/skip sets.
        self._set_indices = None
        self._unique_indices = None
        self._concat_layers = None
        self._set_indices_relu6 = None
        self._equiv_set_flow = None
        equiv_layer = self.source_to_equiv_layer(layer_name)
        self._source_layers = [layer_name]
        self._first_source = equiv_layer
        self._update_equivalence_of_layer(equiv_layer)
        # Rebuild producer equiv layers for every source, including those discovered during the traversal.
        self._sources = list(map(self.source_to_equiv_layer, self._source_layers))
        self._producers = self.sources + self.ew_bouncers
        self._update_equiv_names()
        if len(self._unsupported) == 0 and len(self._skip) == 0:
            self._set_indices, self._unique_indices = self._create_set_indices()
            # add all the important things for equalization
            self.update_equalization_info()

    def update_equalization_info(self):
        self._update_consumers_producers()
        self._update_all_set_indices()
        self._concat_layers = self._get_all_concat_layers()
        self._set_indices_relu6 = self._get_set_indices_of_relu6_activation()
        self._equiv_set_flow = self._hailo_model_flow.edge_subgraph(self._all_edges)

    def _update_consumers_producers(self):
        for producer in self.producers:
            following_consumers = self.find_consumer_successors(producer.layer.full_name)
            producer.set_consumers(following_consumers)
            for consumer in following_consumers:
                consumer.add_producer(producer)

    @staticmethod
    def _create_layer_info(type_of_equiv):
        equiv_set_indices_by_type = []
        for equiv_source in type_of_equiv:
            indices_range_str = f"{equiv_source.layer_indices[0]}-{equiv_source.layer_indices[-1]}"
            flow_id_str = f"{equiv_source.layer_name}:{indices_range_str}"
            equiv_set_indices_by_type.append(flow_id_str)
        return " ,".join([x for x in equiv_set_indices_by_type])

    def equiv_set_info(self):
        return (
            f"sources: {self._create_layer_info(self.sources)} \n"
            f"concat: {self._create_layer_info(self.cc_aggregators)} \n"
            f"ew_bouncer: {self._create_layer_info(self.ew_bouncers)} \n"
            f"activations: {self._create_layer_info(self.activations)}\n"
            f"transparents: {self._create_layer_info(self.transparents)} \n"
            f"consumers: {self._create_layer_info(self.consumers)}\n"
            f"outputs: {self._create_layer_info(self.outputs)}"
        )

    def _update_equiv_names(self):
        for node_index, equiv_layer in enumerate(self.producers):
            equiv_layer.set_name(self.get_equiv_layer_name(equiv_layer, node_index))
        for node_index, equiv_layer in enumerate(self.consumers):
            equiv_layer.set_name(self.get_equiv_layer_name(equiv_layer, node_index))

    def _update_all_set_indices(self):
        for equiv_layer in itertools.chain(
            self.producers,
            self.consumers,
            self.activations,
            self.transparents,
            self.cc_aggregators,
        ):
            equiv_layer.update_set_indices(self.get_set_indices_of_equiv_layer(equiv_layer))

    def __hash__(self):
        return hash(self._first_source)

    def get_equiv_layer_name(self, equiv_layer, index):
        return f"{equiv_layer.layer.full_name}_{index}"

    def source_to_equiv_layer(self, source_name):
        """Wrap source layer ``source_name`` as a producer equiv layer covering all of its output channels.

        Only ``output_shape[-1]`` is used, i.e. the source is assumed to have a single output.
        """
        source = self._hailo_model.layers[source_name]
        all_channels = tuple(range(source.output_shape[-1]))
        return EquivLayer(
            source,
            source,
            all_channels,
            all_channels,
            input_index=0,
            type_of_layer=LayerEquivType.producer,
        )

    @classmethod
    def build_layer_equiv_set(cls, hailo_model, layer_name, equiv_set_algo, handled_sources):
        layer = hailo_model.layers[layer_name]
        callback = layer.get_algo_callback(equiv_set_algo)
        if (not callback(layer_name).is_source) or (layer_name in handled_sources):
            return None
        return cls(hailo_model, layer_name, equiv_set_algo)

    @property
    def equiv_set_flow(self):
        return self._equiv_set_flow

    @property
    def source(self):
        return self._first_source

    @property
    def sources(self):
        return self._sources

    @property
    def source_layers(self):
        return self._source_layers

    @property
    def unique_indices(self):
        return self._unique_indices

    @property
    def concat_layers(self):
        return self._concat_layers

    @property
    def set_indices_relu6(self):
        return self._set_indices_relu6

    def _update_equivalence_of_layer(self, equiv_layer):
        layers_to_handle = list()
        layers_to_handle.append(equiv_layer)
        handled_layers = list()
        while layers_to_handle:
            next_layers_to_handle = list()
            for equiv_layer in layers_to_handle:
                if any(equiv_layer == layer for layer in handled_layers):
                    continue
                next_layers = self._add_successors(equiv_layer)

                handled_layers.append(equiv_layer)
                next_layers_to_handle.extend(next_layers)
            layers_to_handle = next_layers_to_handle

    def _add_successors(self, equiv_layer):
        """Classify every successor of ``equiv_layer``'s layer and record it in the set.

        For each successor in the flow graph:
          1. If the type of ``equiv_layer``'s own layer has an entry in ``_update_features_callbacks`` (a class-level
             mapping not visible here; presumably the ``_update_*_features`` methods - TODO confirm), that callback
             re-maps the tracked channels for this edge. A None result means no tracked channel reaches the
             successor: the edge is still recorded, but the successor is skipped.
          2. The successor's algo callback returns a handler type. It is dispatched through
             ``_handler_type_callbacks`` (also class-level, not visible here) with
             ``(self, equiv_layer, succ_name, index_pred)``.
          3. The edge is recorded in ``self._all_edges``.

        Args:
            equiv_layer: EquivLayer whose layer's successors are expanded.

        Returns:
            list of EquivLayer objects returned by the handlers, to be expanded in the next BFS wave.
        """
        next_layers_to_handle = list()
        successors_names = self._hailo_model_flow.successors(equiv_layer.layer.full_name)
        for succ_name in successors_names:
            input_equiv_layer = equiv_layer
            successor = self._hailo_model.layers[succ_name]
            layer_type = type(equiv_layer.layer)
            # Which input of the successor this edge feeds (first match in the sorted predecessor list).
            index_pred = self._hailo_model_flow.predecessors_sorted(succ_name).index(equiv_layer.layer.full_name)

            if layer_type in self._update_features_callbacks:
                # The current layer changes the channel layout per successor; narrow/permute the tracked indices.
                input_equiv_layer = self._update_features_callbacks[layer_type](
                    self,
                    equiv_layer,
                    succ_name,
                    index_pred,
                )
                self._all_edges.add((equiv_layer.layer.full_name, succ_name))
                if input_equiv_layer is None:
                    continue

            get_algo_equiv_handler_type_callback = successor.get_algo_callback(self._equiv_set_algo)
            handler_type = get_algo_equiv_handler_type_callback(predecessor_index=index_pred).handler_type
            layers_to_handle = self._handler_type_callbacks[handler_type](
                self,
                input_equiv_layer,
                succ_name,
                index_pred,
            )
            self._all_edges.add((input_equiv_layer.layer.full_name, succ_name))
            # Handlers that end the traversal along this path return None.
            if layers_to_handle is None:
                continue
            next_layers_to_handle.extend(layers_to_handle)
        return next_layers_to_handle

    def _update_slice_features(self, equiv_layer, succ, input_index):
        """Restrict ``equiv_layer`` to the channels kept by its feature-slice layer.

        The kept channels are ``np.arange(*layer.features_slice)``. ``succ`` is unused here; it is part of the
        feature-update callback signature used by ``_add_successors``.

        Returns:
            A new EquivLayer on the same layer whose layer indices are positions in the slice output, or None when
            no tracked channel survives the slice.
        """
        kept_channels = np.arange(*equiv_layer.layer.features_slice)
        src_idx, lay_idx = self._get_updated_slice_indices(kept_channels, equiv_layer)
        if not len(lay_idx):
            return None
        return EquivLayer(
            equiv_layer.source_layer,
            equiv_layer.layer,
            tuple(src_idx),
            tuple(lay_idx),
            input_index=input_index,
        )

    @staticmethod
    def _get_updated_slice_indices(slice_indices, equiv_layer):
        current_layer_indices = np.array(equiv_layer.layer_indices)
        current_source_indices = np.array(equiv_layer.source_indices)
        intersection = np.intersect1d(current_layer_indices, slice_indices)
        sort_idx = np.argsort(current_layer_indices)
        wanted_indices = np.searchsorted(current_layer_indices, intersection, sorter=sort_idx)
        new_source_indices = current_source_indices[wanted_indices]
        new_layer_indices = intersection - slice_indices[0]
        return new_source_indices, new_layer_indices

    def _update_feature_splitter_features(self, equiv_layer, succ, input_index):
        """Restrict ``equiv_layer`` to the channel range its feature-splitter layer sends to successor ``succ``.

        Returns:
            A new EquivLayer on the same (splitter) layer whose layer indices are positions in that output, or
            None when none of the tracked channels go to ``succ``.
        """
        channel_range = self._get_feature_splitter_successor_slice(equiv_layer.layer, succ)
        src_idx, lay_idx = self._get_updated_slice_indices(np.arange(*channel_range), equiv_layer)
        if not len(lay_idx):
            return None
        return EquivLayer(
            equiv_layer.source_layer,
            equiv_layer.layer,
            tuple(src_idx),
            tuple(lay_idx),
            input_index=input_index,
        )

    def _update_feature_shuffle_features(self, equiv_layer, succ, input_index):
        """Map ``equiv_layer``'s channels through its feature-shuffle layer.

        ``succ`` is unused here; it is part of the feature-update callback signature.

        Returns:
            A new EquivLayer on the same layer with shuffled layer indices (see ``_shuffled_indices``).
        """
        # The absolute layer indices after the shuffle are what matter (they are what a multi-source successor such
        # as concat sees), so the layer indices are remapped and the source indices are reordered to follow them.
        shuffled_source, shuffled_layer = self._shuffled_indices(equiv_layer)
        return EquivLayer(
            equiv_layer.source_layer,
            equiv_layer.layer,
            shuffled_source,
            shuffled_layer,
            input_index=input_index,
        )

    @property
    def producers(self):
        return self._producers

    @property
    def consumers(self):
        return self._consumers

    @property
    def featurewise(self):
        return self._featurewise

    @property
    def unsupported(self):
        return self._unsupported

    @property
    def skip(self):
        return self._skip

    @property
    def cc_aggregators(self):
        return self._cc_aggregators

    @property
    def outputs(self):
        return self._outputs

    @property
    def transparents(self):
        return self._transparents

    @property
    def multi_source(self):
        return self._multi_source

    @property
    def ew_bouncers(self):
        return self._ew_bouncers

    @property
    def activations(self):
        return self._activations

    @property
    def matmul(self):
        return self._matmul

    @staticmethod
    def _shuffled_indices(equiv_layer):
        new_indices = np.arange(equiv_layer.layer.input_shape[-1])
        num_groups = equiv_layer.layer.groups
        (c,) = new_indices.shape
        shuffled_indices = np.reshape(new_indices, (c // num_groups, num_groups))
        shuffled_indices = np.transpose(shuffled_indices, (1, 0))
        shuffled_indices = np.reshape(shuffled_indices, (c,))
        sort_order = shuffled_indices[np.array(equiv_layer.layer_indices)]
        idx_sort = np.argsort(sort_order)
        new_source_indices = np.array(equiv_layer.source_indices)[idx_sort]
        new_layer_indices = sort_order[idx_sort]
        return tuple(new_source_indices), tuple(new_layer_indices)

    def _get_feature_splitter_successor_slice(self, layer, successor):
        start_channel = 0
        end_channel = 0
        for out_ind, succ_name in enumerate(self._hailo_model.flow.successors_sorted(layer.full_name)):
            succ_output_shape = layer.output_shapes[out_ind]
            end_channel += succ_output_shape[-1]
            if succ_name == successor:
                break
            start_channel += succ_output_shape[-1]
        return start_channel, end_channel

    def _add_concat_layer(self, equiv_layer, succ_name, input_index):
        """Track concat ``succ_name`` as a cc-aggregator fed by ``equiv_layer``.

        Layer indices are shifted by the channel offset at which ``equiv_layer``'s layer enters the concat.

        Returns:
            A one-element list holding the new concat equiv layer, so the traversal continues through it.
        """
        concat = self._hailo_model.layers[succ_name]
        offset = self._get_concat_start_channel(concat, equiv_layer.layer.full_name)
        shifted_indices = tuple(np.array(equiv_layer.layer_indices) + offset)
        aggregated = EquivLayer(
            source_layer=equiv_layer.source_layer,
            layer=concat,
            source_indices=equiv_layer.source_indices,
            layer_indices=shifted_indices,
            input_index=input_index,
        )
        self._cc_aggregators.append(aggregated)
        return [aggregated]

    def _add_consumer_layer(self, equiv_layer, succ_name, input_index):
        """Record successor ``succ_name`` as a consumer of ``equiv_layer``'s channels.

        Returns None, so the traversal stops along this path.
        """
        consumer = EquivLayer(
            source_layer=equiv_layer.source_layer,
            layer=self._hailo_model.layers[succ_name],
            source_indices=equiv_layer.source_indices,
            layer_indices=equiv_layer.layer_indices,
            input_index=input_index,
            type_of_layer=LayerEquivType.consumer,
        )
        self._consumers.append(consumer)

    def _add_output_layer(self, equiv_layer, succ_name, input_index):
        """Record successor ``succ_name`` as an output reached by ``equiv_layer``.

        Returns None, so the traversal stops along this path.
        """
        output = EquivLayer(
            source_layer=equiv_layer.source_layer,
            layer=self._hailo_model.layers[succ_name],
            source_indices=equiv_layer.source_indices,
            layer_indices=equiv_layer.layer_indices,
            input_index=input_index,
        )
        self._outputs.append(output)

    def _add_unsupported_layer(self, equiv_layer, succ_name, input_index):
        """Record successor ``succ_name`` as unsupported.

        Any unsupported member prevents __init__ from computing set indices. Returns None, so the traversal stops
        along this path.
        """
        unsupported = EquivLayer(
            source_layer=equiv_layer.source_layer,
            layer=self._hailo_model.layers[succ_name],
            source_indices=equiv_layer.source_indices,
            layer_indices=equiv_layer.layer_indices,
            input_index=input_index,
        )
        self._unsupported.append(unsupported)

    def _add_transparent_layer(self, equiv_layer, succ_name, input_index):
        """Track transparent successor ``succ_name`` with ``equiv_layer``'s channel mapping unchanged.

        Returns:
            A one-element list holding the new equiv layer, so the traversal continues through it.
        """
        transparent = EquivLayer(
            source_layer=equiv_layer.source_layer,
            layer=self._hailo_model.layers[succ_name],
            source_indices=equiv_layer.source_indices,
            layer_indices=equiv_layer.layer_indices,
            input_index=input_index,
        )
        self._transparents.append(transparent)
        return [transparent]

    def _add_activation_layer(self, equiv_layer, succ_name, input_index):
        """Track activation successor ``succ_name`` with ``equiv_layer``'s channel mapping unchanged.

        Returns:
            A one-element list holding the new equiv layer, so the traversal continues through it.
        """
        activation = EquivLayer(
            source_layer=equiv_layer.source_layer,
            layer=self._hailo_model.layers[succ_name],
            source_indices=equiv_layer.source_indices,
            layer_indices=equiv_layer.layer_indices,
            input_index=input_index,
        )
        self._activations.append(activation)
        return [activation]

    def _add_ew_bouncer_layer(self, equiv_layer, succ_name, input_index):
        """Track element-wise bouncer ``succ_name`` as an additional producer of the set.

        Returns:
            A one-element list holding the new producer equiv layer, so the traversal continues through it.
        """
        bouncer = EquivLayer(
            source_layer=equiv_layer.source_layer,
            layer=self._hailo_model.layers[succ_name],
            source_indices=equiv_layer.source_indices,
            layer_indices=equiv_layer.layer_indices,
            input_index=input_index,
            type_of_layer=LayerEquivType.producer,
        )
        self._ew_bouncers.append(bouncer)
        return [bouncer]

    def _add_skip_layer(self, equiv_layer, succ_name, input_index):
        """Record successor ``succ_name`` as a skip layer.

        Any skip member prevents __init__ from computing set indices. Returns None, so the traversal stops along
        this path.
        """
        skipped = EquivLayer(
            source_layer=equiv_layer.source_layer,
            layer=self._hailo_model.layers[succ_name],
            source_indices=equiv_layer.source_indices,
            layer_indices=equiv_layer.layer_indices,
            input_index=input_index,
        )
        self._skip.append(skipped)

    def _add_featurewise_layer(self, equiv_layer, succ_name, input_index):
        """Track feature-wise successor ``succ_name`` with ``equiv_layer``'s channel mapping unchanged.

        Returns:
            A one-element list holding the new equiv layer, so the traversal continues through it.
        """
        featurewise = EquivLayer(
            source_layer=equiv_layer.source_layer,
            layer=self._hailo_model.layers[succ_name],
            source_indices=equiv_layer.source_indices,
            layer_indices=equiv_layer.layer_indices,
            input_index=input_index,
        )
        self._featurewise.append(featurewise)
        return [featurewise]

    def _add_multi_source_layer(self, equiv_layer, succ_name, input_index):
        """Register ``equiv_layer`` as one arrival at multi-source layer ``succ_name``.

        Every arrival is appended to ``self._multi_source`` and stored in
        ``self._matched_indices[succ_name][<arriving layer name>]`` for later index matching.

        Returns:
            - first arrival at ``succ_name``: the new equiv layer, followed by equiv layers for the sources found
              behind each of its predecessors (those sources are added to the set);
            - a later arrival through the same layer as the first arrival: only the new equiv layer;
            - any other arrival: None, so the traversal stops along this path.
        """
        arriving_name = equiv_layer.layer.full_name
        arrival = EquivLayer(
            equiv_layer.source_layer,
            self._hailo_model.layers[succ_name],
            equiv_layer.source_indices,
            equiv_layer.layer_indices,
            input_index=input_index,
        )
        self._multi_source.append(arrival)
        seen_before = succ_name in self._handled_multi_source

        per_input = self._matched_indices.setdefault(succ_name, dict())
        per_input.setdefault(arriving_name, list()).append(arrival)

        if seen_before:
            via_first_input = self._handled_multi_source[succ_name] == arriving_name
            return [arrival] if via_first_input else None

        self._handled_multi_source[succ_name] = arriving_name
        pending = [arrival]
        for pred_name in list(self._hailo_model_flow.predecessors_sorted(succ_name)):
            pending.extend(self._find_source_predecessors(pred_name))
        return pending

    def _add_matmul_layer(self, equiv_layer: EquivLayer, succ_name: str, input_index: int):
        """Track matmul successor ``succ_name`` with ``equiv_layer``'s channel mapping unchanged.

        Returns:
            A one-element list holding the new equiv layer, so the traversal continues through it.
        """
        matmul_member = EquivLayer(
            source_layer=equiv_layer.source_layer,
            layer=self._hailo_model.layers[succ_name],
            source_indices=equiv_layer.source_indices,
            layer_indices=equiv_layer.layer_indices,
            input_index=input_index,
        )
        self._matmul.append(matmul_member)
        return [matmul_member]

    def _add_matmul_transpose_layer(self, equiv_layer: EquivLayer, succ_name: str, input_index: int):
        """Record ``equiv_layer`` as an input of matmul-transpose layer ``succ_name``.

        The layer itself is never expanded further. On the first arrival at ``succ_name``, the sources behind all
        of its predecessors are looked up (and added to the set) and returned for expansion. Later arrivals return
        None.
        """
        self._matmul.append(
            EquivLayer(
                equiv_layer.source_layer,
                self._hailo_model.layers[succ_name],
                equiv_layer.source_indices,
                equiv_layer.layer_indices,
                input_index=input_index,
            )
        )
        if succ_name in self._handled_multi_source:
            return None
        self._handled_multi_source[succ_name] = equiv_layer.layer.full_name
        discovered = []
        for pred_name in list(self._hailo_model_flow.predecessors_sorted(succ_name)):
            discovered.extend(self._find_source_predecessors(pred_name))
        return discovered

    def _find_source_predecessors(self, layer_name, add_sources=True):
        """
         BFS search for layer's source predecessors

        Args:
            layer_name: layer name
            add_sources: a bolean if to add source

        Returns: predecessors of a given layer

        """
        layers_to_handle = [layer_name]
        additional_sources = list()
        while layers_to_handle:
            additional_sources.extend([lname for lname in layers_to_handle if self._is_source(lname)])
            next_layers_to_handle = list()
            for lname in filter(lambda x: not self._is_source(x), layers_to_handle):
                next_layers_to_handle.extend(self._get_predecessors(lname))
            layers_to_handle = next_layers_to_handle
        new_sources = {src for src in additional_sources if src not in self._source_layers}
        if add_sources:
            self._source_layers.extend(list(new_sources))
        new_sources_layers = [self._hailo_model.layers[source] for source in new_sources]
        equiv_additional_sources = [
            EquivLayer(source, source, tuple(range(source.output_shape[-1])), tuple(range(source.output_shape[-1])))
            for source in new_sources_layers
        ]
        return equiv_additional_sources

    def find_source_predecessors(self, layer):
        return self._find_source_predecessors(layer, False)

    def _is_source(self, lname):
        layer = self._hailo_model.layers[lname]
        callback = layer.get_algo_callback(self._equiv_set_algo)
        return callback().is_source

    def _get_predecessors(self, layer_name):
        """
        Get the direct predecessors of a given layer that the backward source search should follow.

        Args:
            layer_name: name of the layer
        Returns: list of predecessor names in ``predecessors_sorted`` order. For a HailoConvAdd layer, only the
            second sorted predecessor is returned; the reason is not visible here (TODO confirm which input that is).
        """
        layer = self._hailo_model.layers[layer_name]
        from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_add import HailoConvAdd

        ordered = list(self._hailo_model_flow.predecessors_sorted(layer_name))
        return ordered[1:2] if isinstance(layer, HailoConvAdd) else ordered

    def _unexpected_layer_type(self, equiv_layer, succ_name, input_index):
        succ = self._hailo_model.layers[succ_name]
        raise LayerEquivError(f"Layer {succ_name} is of type {type(succ)}. {type(succ)} can't be a successor.")

    def _unknown_layer_type(self, equiv_layer, succ_name):
        succ = self._hailo_model.layers[succ_name]
        raise LayerEquivError(f"Layer {succ_name} is of type {type(succ)}. Unknown layer type.")

    def _not_implemented_type(self, equiv_layer, succ_name):
        succ = self._hailo_model.layers[succ_name]
        raise NotImplementedError(f"Layer {succ_name} is of type {type(succ)}, which is not supported yet")

    def _get_concat_start_channel(self, concat_layer, layer_name):
        layer_index = self._hailo_model_flow.get_edge_input_index(layer_name, concat_layer.full_name)

        channel = sum(shape[-1] for shape in concat_layer.input_shapes[:layer_index])
        return channel

    def _create_set_indices(self):
        """
        If we had multiple source layers, the source_index by itself it not enough. Hence, in this case we will create
        for each (source, source_index) a unique set_index. This function finds overlapping indices and
        creates this "universal" indices for all the source's channels

        Returns: (dict, list) tuple. The dict keys is (source_name, index) pair, and the value in in "universal" index.
                                     The list contains sets of (source_name, index) for each "universal" index
        """
        layer_matches = list()
        layers_unified_index = dict()
        # Iterate over overlap found in multiple source layers.
        for layer_name in self._matched_indices:
            layer = self._hailo_model.layers[layer_name]
            matches = self._find_all_overlap_indices(self._matched_indices[layer_name], layer.output_shape[-1])
            # For each overlap, resolve in one of 4 ways
            for source_index1, source_index2 in matches:
                source_index1 = tuple(source_index1)
                source_index2 = tuple(source_index2)
                if source_index1 == source_index2:
                    continue

                # Both indices does not exist: Create new value
                if (source_index1 not in layers_unified_index) and (source_index2 not in layers_unified_index):
                    new_index = len(layer_matches)
                    layer_matches.append({source_index1, source_index2})
                    layers_unified_index[source_index1] = new_index
                    layers_unified_index[source_index2] = new_index
                # Index 1 exists and index 2 doesn't: match indices based on index1
                elif (source_index1 in layers_unified_index) and (source_index2 not in layers_unified_index):
                    unified_index = layers_unified_index[source_index1]
                    layer_matches[unified_index].add(source_index2)
                    layers_unified_index[source_index2] = unified_index
                # Index 2 exists and index 1 doesn't: match indices based on index2
                elif (source_index1 not in layers_unified_index) and (source_index2 in layers_unified_index):
                    unified_index = layers_unified_index[source_index2]
                    layer_matches[unified_index].add(source_index1)
                    layers_unified_index[source_index1] = unified_index
                # Both indices exist: Copy index2 entries to index 1 and Delete index2 entries
                else:
                    unified_index1 = layers_unified_index[source_index1]
                    unified_index2 = layers_unified_index[source_index2]
                    if unified_index1 == unified_index2:
                        continue
                    layer_matches2 = layer_matches[unified_index2]
                    for src_ind in layer_matches2:
                        layers_unified_index[src_ind] = unified_index1
                    layer_matches[unified_index1] |= layer_matches[unified_index2]
                    layer_matches[unified_index2] = None

        # Add missing indices (indices that are not part of multiple source layer)
        for source_name in self._source_layers:
            source = self._hailo_model.layers[source_name]
            for i in range(source.output_shape[-1]):
                if (source_name, i) not in layers_unified_index:
                    layers_unified_index[source_name] = len(layer_matches)
                    layer_matches.append({(source_name, i)})

        # Create easy access dict and list for (source, index) pairs.
        layers_unified_index = dict()
        layer_matches = list(filter(None, layer_matches))  # Remove empty entries
        for i, indices in enumerate(layer_matches):
            for j in indices:
                if j in layers_unified_index:
                    raise LayerEquivError(f"Encountered same layer twice during indices resolving - {j}")
                layers_unified_index[j] = i
        return layers_unified_index, layer_matches

    def _find_all_overlap_indices(self, overlap_info_node, feature_count):
        """
        Find all overlapping indices in multiple source layers.

        For every channel of the multiple source layer, pair the (source_name, source_index) that reaches it through
        the first recorded input with the one that reaches it through the second recorded input.

        :param overlap_info_node: overlap info of a multiple source layer - its entry in ``self._matched_indices``,
            a dict from arriving input layer name to the EquivLayers that arrived through it. Exactly two entries
            are expected (the values are unpacked into two names below).
        :param feature_count: number of channels of the multiple source layer (the caller passes its
            ``output_shape[-1]``)
        :return: object ndarray of per-channel pairs; the caller iterates it as (source_index1, source_index2).
        """
        input1, input2 = overlap_info_node.values()

        # Channel -> (source_name, source_index) through the first input. Channels that no arriving EquivLayer
        # covers keep the placeholder 0 from np.zeros.
        input1_indices = np.zeros(feature_count, dtype=object)
        for layer_in in input1:
            source_indices = [(layer_in.source_layer.full_name, i) for i in layer_in.source_indices]
            # NOTE(review): numpy may coerce a list of tuples assigned into a 1-D object array via fancy indexing
            # into a 2-D value array that fails to broadcast - confirm with the numpy version in use.
            input1_indices[np.array(layer_in.layer_indices)] = source_indices

        # Same mapping through the second input.
        input2_indices = np.zeros(feature_count, dtype=object)
        for layer_in in input2:
            source_indices = [(layer_in.source_layer.full_name, i) for i in layer_in.source_indices]
            input2_indices[np.array(layer_in.layer_indices)] = source_indices

        return np.array(list(zip(input1_indices, input2_indices)), dtype=object)

    def get_set_indices_of_equiv_layer(self, equiv_layer):
        return self.get_set_indices(equiv_layer.source_layer.full_name, equiv_layer.source_indices)

    def get_set_indices(self, source_layer, indices):
        source_index_pairs = map(lambda x: (source_layer, x), indices)
        return np.fromiter(map(lambda x: self._set_indices[x], source_index_pairs), dtype=np.int32)

    @staticmethod
    def handle_set_indices_conflicts(set_indices, array, callback):
        """
        Merge the entries of ``array`` that share the same set index.

        Each equiv_layer has (source_layer, layer, source_indices, layer_indices), and every layer_index maps to
        one set_index. Several layer_indices of the same layer may map to the same set_index; this function
        groups those duplicates and lets ``callback`` reduce each group to a single value.

        :param set_indices: 1-D array-like of "universal" set indices, one per item of ``array``
        :param array: array-like with the same length as ``set_indices``
        :param callback: called once per distinct set index with the slice of ``array`` entries sharing it
            (entries keep their original relative order); its return value is the merged entry
        :return: tuple ``(unique_set_indices, merged)``: the sorted numpy array of distinct set indices, and a
            list holding ``callback``'s result for each of them, in the same order
        :raises LayerEquivError: if ``set_indices`` and ``array`` differ in length
        """
        if len(set_indices) != len(array):
            raise LayerEquivError(
                f"The array is of length {len(array)} and the length of set_indices is {len(set_indices)} "
                f"but must be in the same length .",
            )
        # Accept plain sequences too; fancy indexing below needs numpy arrays (no-op for existing arrays).
        set_indices = np.asarray(set_indices)
        array = np.asarray(array)

        # Stable sort so duplicates reach the callback in their original order (default quicksort is unstable).
        order = np.argsort(set_indices, kind="stable")
        # Gather once; the previous code re-gathered the whole array on every group (O(n^2)).
        sorted_array = array[order]
        unique_set_indices, group_starts, group_counts = np.unique(
            set_indices[order],
            return_index=True,
            return_counts=True,
        )
        merged = [
            callback(sorted_array[start : start + count]) for start, count in zip(group_starts, group_counts)
        ]
        return unique_set_indices, merged

    def find_consumer_successors(self, start_layer):
        """
        Return the registered consumer equiv-layers that are reachable from ``start_layer``.

        Reachability is decided by ``equiv_bound_successors`` restricted to consumer handler types.
        :param start_layer: name of the layer to start the search from
        :return: list of items from ``self._consumers`` (original order kept) whose layer name was reached
        """
        reachable_consumer_names = self.equiv_bound_successors(start_layer, LayerHandlerType.consumer)
        return list(
            filter(lambda consumer: consumer.layer.full_name in reachable_consumer_names, self._consumers),
        )

    def equiv_bound_successors(self, start_layer, wanted_handler_type=None):
        """
        BFS over the model flow graph, collecting successors reachable from a given start layer.

        Every edge (layer -> successor) is classified by
        ``successor.get_algo_callback(self._equiv_set_algo)(predecessor_index=...)``. The walk continues
        through the successor only when that handler type is featurewise, multi_source, ew_bouncer,
        transparent or cc_aggregator. Any other type stops the walk along that edge, although the successor
        itself may still be collected.

        :param start_layer: name of the layer to start the BFS from
        :param wanted_handler_type: which successors to collect (a LayerHandlerType). If None, every
            successor examined during the walk is collected.
        :return: set with the names of the collected layers
        """
        passthrough_types = {
            LayerHandlerType.featurewise,
            LayerHandlerType.multi_source,
            LayerHandlerType.ew_bouncer,
            LayerHandlerType.transparent,
            LayerHandlerType.cc_aggregator,
        }
        wanted_layers = set()
        # A node's outgoing edges do not depend on how it was reached, so expanding it once is enough.
        # Without this, diamonds (e.g. residual branches) were re-expanded once per incoming path, and a
        # cycle would never terminate.
        # NOTE(review): assumes the classifier callback is side-effect free (it is now called once per edge).
        expanded = {start_layer}
        frontier = [start_layer]
        while frontier:
            next_frontier = []
            for layer_name in frontier:
                for succ_name in self._hailo_model_flow.successors(layer_name):
                    successor = self._hailo_model.layers[succ_name]
                    # The handler type is per edge: it depends on which input of the successor we arrive at.
                    # NOTE(review): .index() picks the first slot if layer_name feeds succ_name more than once.
                    index_pred = self._hailo_model_flow.predecessors_sorted(succ_name).index(layer_name)
                    layer_handler_type, _ = successor.get_algo_callback(self._equiv_set_algo)(
                        predecessor_index=index_pred,
                    )
                    if (wanted_handler_type is None) or (layer_handler_type == wanted_handler_type):
                        wanted_layers.add(succ_name)
                    if layer_handler_type in passthrough_types and succ_name not in expanded:
                        expanded.add(succ_name)
                        next_frontier.append(succ_name)
            frontier = next_frontier
        return wanted_layers

    # These layer classes are kept out of ``_handler_type_callbacks`` because they play a distinct role:
    # their feature indices are re-mapped by the dedicated functions below, no matter which handler type
    # the layer is classified as.
    # NOTE(review): importing inside the class body binds these names as class attributes; presumably done
    # here instead of at module top to avoid an import cycle - TODO confirm.
    from hailo_model_optimization.acceleras.hailo_layers.hailo_feature_shuffle import HailoFeatureShuffle
    from hailo_model_optimization.acceleras.hailo_layers.hailo_feature_splitter import HailoFeatureSplitter
    from hailo_model_optimization.acceleras.hailo_layers.hailo_slice import HailoSlice

    # Layer class -> function that updates the feature indices for that layer type.
    # The values are the plain functions as looked up at class-definition time (defined earlier in the class
    # body or module), not bound methods - callers presumably pass the instance explicitly; verify at call sites.
    _update_features_callbacks = {
        HailoFeatureSplitter: _update_feature_splitter_features,
        HailoSlice: _update_slice_features,
        HailoFeatureShuffle: _update_feature_shuffle_features,
    }

    # LayerHandlerType -> function that adds a layer of that handler type (same plain-function caveat as above).
    # ``unexpected`` and ``undefined`` share one handler; ``None`` gets a separate fallback.
    _handler_type_callbacks = {
        LayerHandlerType.transparent: _add_transparent_layer,
        LayerHandlerType.cc_aggregator: _add_concat_layer,
        LayerHandlerType.consumer: _add_consumer_layer,
        LayerHandlerType.multi_source: _add_multi_source_layer,
        LayerHandlerType.unsupported: _add_unsupported_layer,
        LayerHandlerType.featurewise: _add_featurewise_layer,
        LayerHandlerType.output: _add_output_layer,
        LayerHandlerType.skip: _add_skip_layer,
        LayerHandlerType.ew_bouncer: _add_ew_bouncer_layer,
        LayerHandlerType.activation: _add_activation_layer,
        LayerHandlerType.matmul: _add_matmul_layer,
        LayerHandlerType.matmul_transpose: _add_matmul_transpose_layer,
        LayerHandlerType.unexpected: _unexpected_layer_type,
        LayerHandlerType.undefined: _unexpected_layer_type,
        None: _unknown_layer_type,  # Should never occur, since classifier raises an error before.
    }

    def _get_all_concat_layers(self):
        """Get all the layers this are concatenated with this equive class"""
        concats_names_values = []
        if len(self.cc_aggregators) == 0:
            return concats_names_values
        for source in self.source_layers:
            cc_aggregators = self.equiv_bound_successors(source, LayerHandlerType.cc_aggregator)
            multi_source = self.equiv_bound_successors(source, LayerHandlerType.multi_source)
            for cross_layer in itertools.chain(cc_aggregators, multi_source):
                source_preds = self.find_source_predecessors(cross_layer)
                for pred in source_preds:
                    src_name = pred.layer.full_name
                    concats_names_values.append(src_name)

        return concats_names_values

    def _get_set_indices_of_relu6_activation(self):
        self._relu6_layers = list()
        # calculate in advance all the flow indices that are connected to relu6 activation.
        all_set_indices = set()
        for equiv_layer in itertools.chain(self.producers, self.activations):
            if equiv_layer.layer.activation_atomic_op is None:
                continue
            if equiv_layer.layer.act_op.act_func != tf.nn.relu6:
                continue
            self._relu6_layers.append(equiv_layer)
            act_set_idx = set(equiv_layer.set_indices)
            all_set_indices = all_set_indices.union(act_set_idx)
        all_set_indices = list(all_set_indices)
        return all_set_indices
