from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class TransposeLayer(LayerWithParams):
    """HN layer that reorders the dimensions of its input according to ``perm``.

    ``perm`` is a sequence of input-axis indices: output axis ``k`` takes the
    size of input axis ``perm[k]`` (see ``update_output_shapes``).
    """

    def __init__(self):
        super().__init__()
        self._op = LayerType.transpose
        # Axis permutation; stays None until ``create`` or the ``perm`` setter assigns it.
        self._perm = None

    @classmethod
    def create(cls, original_name, input_vertex_order, perm, output_shapes=None):
        """Build a transpose layer through the parent factory and attach ``perm``.

        Args:
            original_name: Forwarded unchanged to the parent class's ``create``.
            input_vertex_order: Forwarded unchanged to the parent class's ``create``.
            perm: Axis permutation, stored as-is (neither copied nor validated).
            output_shapes: Forwarded unchanged to the parent class's ``create``.

        Returns:
            The layer produced by the parent factory, with ``perm`` assigned.
        """
        instance = super().create(original_name, input_vertex_order, output_shapes)
        instance._perm = perm
        return instance

    @property
    def perm(self):
        """Axis permutation applied by this layer (``None`` until assigned)."""
        return self._perm

    @perm.setter
    def perm(self, perm):
        self._perm = perm

    def update_output_shapes(self, **kwargs):
        """Recompute ``output_shapes`` by permuting ``input_shape`` with ``perm``.

        Each of the ``output_copies`` outputs receives its own list copy of the
        permuted shape. Extra keyword arguments are accepted and ignored.

        NOTE(review): ``perm`` must be set before calling; iterating ``None``
        raises ``TypeError``.
        """
        permuted_shape = [self.input_shape[axis] for axis in self._perm]
        self.output_shapes = [list(permuted_shape) for _copy in range(self.output_copies)]

    @staticmethod
    def _transparent_classification():
        """Return the classification this layer reports to every handler pass:
        transparent and not a source."""
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def get_equalization_handler_type(self, predecessor=None):
        """Report transpose as transparent for equalization; ``predecessor`` is ignored."""
        return self._transparent_classification()

    def get_params_sorter_handler_type(self, predecessor=None):
        """Report transpose as transparent for params sorting; ``predecessor`` is ignored."""
        return self._transparent_classification()

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Report transpose as transparent for dead-channel removal; ``predecessor`` is ignored."""
        return self._transparent_classification()

    def ibc_supported(self):
        """Report IBC as unsupported for this layer."""
        return LayerSupportStatus.unsupported

    def is_first_in_gcn_block(self, graph):
        """Check whether this transpose opens the specific layer pattern below.

        All of the following must hold:

        * ``self`` has exactly one predecessor, and it is an ``avgpool``.
        * Starting at ``self``, each hop has exactly one successor, in the order
          ``concat`` -> ``base_conv`` -> ``base_activation``.
        * The activation has exactly two successors, both ``base_slice``.
        * At least one successor of those two slices is a ``transpose``.

        Args:
            graph: Graph object exposing ``successors(node)`` and
                ``predecessors(node)``.

        Returns:
            True if the pattern matches, False otherwise.
        """
        succs = graph.successors
        preds = graph.predecessors

        parents = list(preds(self))
        if len(parents) != 1 or parents[0].op != LayerType.avgpool:
            return False

        # Follow the single-successor chain self -> concat -> base_conv -> base_activation.
        node = self
        for expected_op in (LayerType.concat, LayerType.base_conv, LayerType.base_activation):
            children = list(succs(node))
            if len(children) != 1 or children[0].op != expected_op:
                return False
            node = children[0]

        # ``node`` is now the activation; it must fan out into exactly two slices.
        branches = list(succs(node))
        if len(branches) != 2:
            return False
        if any(branch.op != LayerType.base_slice for branch in branches):
            return False

        grandchildren = [child for branch in branches for child in succs(branch)]
        return any(child.op == LayerType.transpose for child in grandchildren)
