from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer import Layer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class EWMaxLayer(Layer):
    """Element-wise maximum layer (``LayerType.ew_max``) taking exactly two inputs.

    The output shape is computed as the per-dimension maximum of the two
    predecessor output shapes.
    """

    def __init__(self):
        super().__init__()
        self._op = LayerType.ew_max
        # Ordered inputs of this layer; populated via the ``input_list`` setter
        # or ``append_to_input_list``.
        self._input_list = []
        # Exact number of inputs accepted; enforced in ``update_output_shapes``.
        self._number_of_inputs_supported = 2

    @property
    def input_list(self):
        """Return the list of inputs of this layer."""
        return self._input_list

    @input_list.setter
    def input_list(self, input_list):
        """Replace the list of inputs of this layer."""
        self._input_list = input_list

    def append_to_input_list(self, inp):
        """Append ``inp`` to the end of this layer's input list."""
        self._input_list.append(inp)

    def update_output_shapes(self, **kwargs):
        """Validate the number of inputs, then delegate to ``Layer.update_output_shapes``.

        Raises:
            UnsupportedModelError: if the number of inputs differs from
                ``self._number_of_inputs_supported``.
        """
        expected = self._number_of_inputs_supported
        found = len(self._input_list)
        # Use the declared input count instead of a hard-coded literal so the
        # validation and the attribute cannot drift apart.
        if found != expected:
            raise UnsupportedModelError(f"{self.full_name_msg} expects {expected} inputs but found {found}")
        super().update_output_shapes(**kwargs)

    def _calc_output_shape(self):
        """Return the per-dimension maximum of the two input shapes as a list."""
        input0_shape = self.pred_layer_output_shape(self._input_list[0], True)
        input1_shape = self.pred_layer_output_shape(self._input_list[1], True)
        # NOTE(review): zip stops at the shorter shape, so inputs of different
        # rank are silently truncated rather than rejected — confirm callers
        # always supply equal-rank shapes.
        return [max(dim0, dim1) for dim0, dim1 in zip(input0_shape, input1_shape)]

    def get_equalization_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification for equalization.

        ``predecessor`` is accepted for interface compatibility and is unused.
        """
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification for params sorting.

        ``predecessor`` is accepted for interface compatibility and is unused.
        """
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Return an ``unexpected``, non-source classification for dead-channel removal.

        ``predecessor`` is accepted for interface compatibility and is unused.
        """
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def ibc_supported(self):
        """Return ``LayerSupportStatus.unexpected`` for IBC support."""
        return LayerSupportStatus.unexpected
