from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer import Layer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class EinsumLayer(Layer):
    """Layer whose op type is ``LayerType.einsum``.

    On top of whatever the ``Layer`` base class stores, an instance carries
    two extra values, ``equation`` and ``weights``. Both are kept exactly as
    supplied; this class performs no validation or parsing of either.
    """

    def __init__(self):
        """Initialise the base layer, tag the op type, and clear the extras."""
        super().__init__()
        self._op = LayerType.einsum
        # Both stay None until assigned through the properties (see `create`).
        self._equation = self._weights = None

    @classmethod
    def create(cls, original_name, input_vertex_order, equation, weights, output_shapes=None):
        """Build a layer via the base-class ``create`` and attach the extras.

        ``original_name``, ``input_vertex_order`` and ``output_shapes`` are
        forwarded positionally to the parent ``create``; ``equation`` and
        ``weights`` are then stored on the returned object unchanged.

        Returns:
            The object produced by the parent ``create`` with ``equation``
            and ``weights`` set.
        """
        new_layer = super().create(original_name, input_vertex_order, output_shapes)
        new_layer.equation = equation
        new_layer.weights = weights

        return new_layer

    @property
    def equation(self):
        """The stored equation value (``None`` until set)."""
        return self._equation

    @equation.setter
    def equation(self, value):
        self._equation = value

    @property
    def weights(self):
        """The stored weights value (``None`` until set)."""
        return self._weights

    @weights.setter
    def weights(self, value):
        self._weights = value

    def _calc_output_shape(self):
        """Return the first entry of ``self.output_shapes``.

        Nothing is derived from ``equation`` here; the shape is taken as-is
        from the already-populated ``output_shapes``.
        """
        known_shapes = self.output_shapes
        return known_shapes[0]

    def get_equalization_handler_type(self, predecessor=None):
        """Classify this layer as a transparent non-source for equalization.

        ``predecessor`` is accepted but not used.
        """
        handler_kind = LayerHandlerType.transparent
        return EquivClassification(handler_kind, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Classify this layer as an unsupported non-source for params sorting.

        ``predecessor`` is accepted but not used.
        """
        handler_kind = LayerHandlerType.unsupported
        return EquivClassification(handler_kind, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Classify this layer as an unsupported non-source for dead-channel removal.

        ``predecessor`` is accepted but not used.
        """
        handler_kind = LayerHandlerType.unsupported
        return EquivClassification(handler_kind, is_source=False)

    def ibc_supported(self):
        """Report IBC as unsupported for this layer."""
        support_status = LayerSupportStatus.unsupported
        return support_status
