from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer import Layer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class DemuxLayer(Layer):
    """Layer of type ``LayerType.demux`` that holds one output shape per successor.

    A demux has ``len(output_shapes) > 1`` while ``output_copies == 1``, so the
    output shape seen by a successor depends on which successor asks (see
    ``_get_output_shape``). Demux layers cannot be loaded from an HN
    (``from_hn`` always raises). They are meant to be created by the
    ``mux_demux`` model script command. They can be serialized to and loaded
    from protobuf.
    """

    # This layer requires neither native nor quantized weights.
    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False
    _IS_REAL_LAYER = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.demux

    def _calc_output_shape(self):
        """Return the current output shapes unchanged.

        Output shapes are not recomputed here; whatever is currently stored in
        ``output_shapes`` is returned as-is.
        """
        return self.output_shapes

    @classmethod
    def from_hn(cls, hn):
        """Reject construction from an HN.

        Raises:
            UnsupportedModelError: always; demux layers are not allowed in HNs.
        """
        raise UnsupportedModelError(
            "Demux layer is not allowed in HNs. Instead, use the mux_demux model script command to create this layer.",
        )

    @classmethod
    def _validate_input(cls, layer):
        """Check that a freshly loaded demux layer carries output shapes.

        Args:
            layer: The layer instance produced by the base-class loader.

        Raises:
            UnsupportedModelError: if ``layer._output_shapes`` is empty or None.
        """
        if not layer._output_shapes:
            raise UnsupportedModelError(f"{layer.full_name_msg} must have output_shapes")

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize via the base class, then mark the node type as a network demux.

        Returns:
            The protobuf node created by ``Layer.to_pb`` with ``type`` set to
            ``PROTO_NETWORK_DEMUX``.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_DEMUX
        return node

    # NOTE(review): the three accessors below read dims 1/2/3 of the output
    # shapes and are named height/width/features, which presumably assumes an
    # NHWC-style layout. They pass ``validate=False``, possibly because the
    # per-output shapes may differ. Confirm both against
    # ``_get_shape_single_dim``.
    @property
    def output_height(self):
        """Dimension 1 of the output shapes (looked up with ``validate=False``)."""
        return self._get_shape_single_dim(self._output_shapes, 1, validate=False)

    @property
    def output_width(self):
        """Dimension 2 of the output shapes (looked up with ``validate=False``)."""
        return self._get_shape_single_dim(self._output_shapes, 2, validate=False)

    @property
    def output_features(self):
        """Dimension 3 of the output shapes (looked up with ``validate=False``)."""
        return self._get_shape_single_dim(self._output_shapes, 3, validate=False)

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Load the layer with the base-class protobuf loader, then validate it.

        Raises:
            UnsupportedModelError: if the loaded layer has no output shapes.
        """
        layer = super().from_pb(pb, pb_wrapper)
        cls._validate_input(layer)
        return layer

    def update_output_shapes(self, **kwargs):
        """Re-assign ``output_shapes`` from ``_calc_output_shape``; ``kwargs`` are ignored.

        Overridden because a demux has ``len(output_shapes) > 1`` but
        ``output_copies == 1``.
        """
        self.output_shapes = self._calc_output_shape()

    def _get_output_shape(self, validate=False, layer_name=None, layer_index=None):
        """Return the output shape that feeds one particular successor.

        Args:
            validate: Accepted for interface compatibility; not used here.
            layer_name: Name of the successor layer. Always required.
            layer_index: Output index of the successor. Required when
                ``self._output_indices`` is non-empty; ignored otherwise.

        Returns:
            The matching entry of ``self._output_shapes``. When output indices
            are set, the lookup uses the position of ``layer_index`` in
            ``self._output_indices``. Otherwise it uses the position of
            ``layer_name`` in ``self.outputs``.

        Raises:
            UnsupportedModelError: if ``layer_name`` is None, or if output
                indices are set and ``layer_index`` is None.
            ValueError: propagated from ``.index`` when the successor is not
                found in the corresponding sequence.
        """
        # With several output shapes, the shape is ambiguous unless we know
        # which successor is asking.
        if layer_name is None:
            raise UnsupportedModelError(f"{self.full_name_msg} successor name is missing, output shape is ambiguous")
        if self._output_indices:
            if layer_index is None:
                raise UnsupportedModelError(
                    f"{self.full_name_msg} successor index is missing, output shape is ambiguous",
                )
            return self._output_shapes[self._output_indices.index(layer_index)]
        return self._output_shapes[self.outputs.index(layer_name)]

    # TODO: set a real value for the handler types below instead of leaving them
    # undefined (currently reported as ``LayerHandlerType.unsupported``).
    def get_equalization_handler_type(self, predecessor=None):
        """Equalization handler: unsupported, not a source. ``predecessor`` is ignored."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Params-sorter handler: unsupported, not a source. ``predecessor`` is ignored."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Dead-channels-removal handler: unsupported, not a source. ``predecessor`` is ignored."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        """Report IBC support status for this layer (always unsupported)."""
        return LayerSupportStatus.unsupported