import numpy as np

from hailo_model_optimization.acceleras.atomic_ops.feature_permute_op import FeaturePermuteOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EquivClassification,
    LayerHandlerType,
    LayerType,
)


class HailoFeatureShuffle(BaseHailoSingleAtomic):
    """
    Implements feature shuffle.

    Wraps a single ``FeaturePermuteOp`` whose ``feature_order`` is computed in ``_build`` from the
    last dimension of the input shape, ``groups`` and ``groups_slice``.

    Two orderings are produced:

    * ``groups_slice`` empty: classic channel shuffle. The feature indices are laid out as a
      ``(groups, group_size)`` grid, transposed and flattened, e.g. ``f_in=6, groups=2`` gives
      ``[0, 3, 1, 4, 2, 5]``.
    * ``groups_slice = (start, end, step)``: the features are split into ``groups`` contiguous
      groups; inside every group the indices picked by the slice come first, followed by the
      group's remaining indices in ascending order, e.g. ``f_in=8, groups=2, groups_slice=(2, 4, 1)``
      gives ``[2, 3, 0, 1, 6, 7, 4, 5]``.
    """

    _hn_type = LayerType.FEATURE_SHUFFLE
    OP_NAME = "feature_permute_op"

    def __init__(self, name: str, groups, groups_slice, logger=None, **kwargs):
        """
        Args:
            name: Layer name; the wrapped op is named ``f"{name}/feature_permute_op"``.
            groups: Number of feature groups; must be positive and divide the input feature count.
            groups_slice: ``(start, end, step)`` applied inside each group, or an empty sequence
                for a plain channel shuffle.
            logger: Forwarded to the atomic op and to the base layer.
            **kwargs: Forwarded to ``BaseHailoSingleAtomic.__init__``.
        """
        atomic_op = FeaturePermuteOp(f"{name}/{self.OP_NAME}", logger=logger)
        super().__init__(name=name, core_op=atomic_op, logger=logger, **kwargs)
        self._groups = groups
        self._groups_slice = groups_slice

    @property
    def groups(self):
        """Number of feature groups the input features are split into."""
        return self._groups

    @property
    def groups_slice(self):
        """``(start, end, step)`` applied inside each group, or an empty sequence."""
        return self._groups_slice

    def _build(self, input_shape):
        """Compute the feature order for ``input_shape[-1]`` features and set it on the atomic op."""
        self.atomic_op.feature_order = self._compute_shuffle_order(input_shape[-1], self.groups, self.groups_slice)

    @staticmethod
    def _compute_shuffle_order(f_in, groups, groups_slice):
        """
        Return the 1-D integer array of input feature indices that forms the shuffled order.

        Raises:
            ValueError: if ``groups`` is not positive or does not divide ``f_in``.
        """
        # Without this check the plain branch fails with an opaque reshape error and the sliced
        # branch silently drops the trailing ``f_in % groups`` features.
        if groups <= 0 or f_in % groups != 0:
            raise ValueError(
                f"feature_shuffle expects a positive number of groups that divides the feature count, "
                f"got groups={groups} for {f_in} features"
            )
        group_size = f_in // groups

        if not groups_slice:
            # Channel shuffle: (groups, group_size) grid -> transpose -> flatten (row-major).
            return np.arange(f_in, dtype=np.int32).reshape(groups, group_size).T.reshape(-1)

        start, end, step = groups_slice
        # NOTE(review): `step` is applied twice - once by arange and again by the slice below - so
        # for step != 1 the resulting order does not cover every feature of a group. Kept as-is to
        # preserve existing behavior; confirm this is intended.
        group_indices = np.arange(0, group_size, step)
        selected = group_indices[start:end:step]
        # Order-preserving difference: a Python set difference has no defined iteration order, and
        # when empty it becomes a float64 array that turns the whole order into floats.
        remaining = group_indices[~np.isin(group_indices, selected)]
        within_group = np.concatenate([selected, remaining])

        # Same in-group order for every group, shifted by that group's base offset, concatenated.
        group_offsets = np.arange(groups)[:, np.newaxis] * group_size
        return (group_offsets + within_group).reshape(-1)

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """
        Build the layer from an HN element.

        Reads ``hn_element["params"]["groups"]`` and the optional ``"groups_slice"`` (defaulting to
        an empty list, i.e. a plain channel shuffle), then calls ``finalize_from_hn``.
        """
        params = hn_element["params"]
        groups = params["groups"]
        groups_slice = params.get("groups_slice", [])
        layer = cls(name=lname, groups=groups, groups_slice=groups_slice, logger=logger)
        layer.finalize_from_hn(hn_element)
        return layer

    def get_equalization_handler_type(self, predecessor_index=None):
        """Report this layer as transparent (and not a source) for equalization, for any predecessor."""
        return EquivClassification(LayerHandlerType.transparent, is_source=False)
