from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer import Layer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class EqualLayer(Layer):
    """Layer whose op is ``LayerType.equal``.

    Accepts up to two inputs (``_number_of_inputs_supported`` is 2) and may
    carry an optional ``constant_input`` value supplied at creation time.
    """

    def __init__(self):
        """Initialise the base layer, then set the equal-specific defaults."""
        super().__init__()
        self._op = LayerType.equal
        # No constant operand until `create` (or the setter) provides one.
        self._constant_input = None
        self._number_of_inputs_supported = 2

    @classmethod
    def create(cls, original_name, input_vertex_order, constant_input, output_shapes=None):
        """Build a layer through the parent ``create`` and attach the constant input.

        Args:
            original_name: Forwarded unchanged to the parent ``create``.
            input_vertex_order: Forwarded unchanged to the parent ``create``.
            constant_input: Value stored on the new layer as ``constant_input``.
            output_shapes: Forwarded unchanged to the parent ``create``.

        Returns:
            The object returned by the parent ``create``, with its
            ``constant_input`` set.
        """
        instance = super().create(original_name, input_vertex_order, output_shapes)
        instance.constant_input = constant_input
        return instance

    @property
    def constant_input(self):
        """The stored constant input, or ``None`` when none was set."""
        return self._constant_input

    @constant_input.setter
    def constant_input(self, constant_input):
        self._constant_input = constant_input

    def get_equalization_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification; ``predecessor`` is unused."""
        handler_type = LayerHandlerType.unexpected
        return EquivClassification(handler_type, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification; ``predecessor`` is unused."""
        handler_type = LayerHandlerType.unexpected
        return EquivClassification(handler_type, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification; ``predecessor`` is unused."""
        handler_type = LayerHandlerType.unexpected
        return EquivClassification(handler_type, is_source=False)

    def ibc_supported(self):
        """Report ``LayerSupportStatus.unsupported`` for this layer."""
        return LayerSupportStatus.unsupported

    @property
    def input_shape(self):
        """Combined input shape of the layer.

        With exactly one input shape, that shape object is returned as is.
        Otherwise a new list is built holding, per position, the larger of
        the first two input shapes' entries (``zip`` stops at the shorter
        shape).
        NOTE(review): presumably this resolves broadcasting between the two
        operands -- confirm against the callers.
        """
        shapes = self._input_shapes
        first = shapes[0]
        if len(shapes) != 1:
            second = shapes[1]
            return [max(dims) for dims in zip(first, second)]

        return first
