from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.slice import SliceLayer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class FusedSliceLayer(SliceLayer):
    """Slice layer variant serialized to protobuf as ``PROTO_NETWORK_SLICE``.

    On top of the base ``SliceLayer`` validation, only unit strides (index 2
    of each slice spec) are accepted on the height, width and features axes.
    """

    # Class-level flags consumed outside this class; this layer declares that
    # it needs neither native nor quantized weights.
    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False
    _IS_REAL_LAYER = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.slice
        # Slice specs start unset; they are expected to be populated later
        # (not shown here) before validation or handler queries run.
        self._height_slice = self._width_slice = self._features_slice = None

    def _validate_slice_params(self):
        """Run the base validation, then reject any non-unit stride.

        Raises:
            UnsupportedModelError: if the stride (element 2) of the height,
                width or features slice differs from 1.
        """
        super()._validate_slice_params()
        # Properties are read lazily and in order, so the check stops at the
        # first offending axis exactly like a chained ``or``.
        axis_props = ("height_slice", "width_slice", "features_slice")
        has_non_unit_stride = any(getattr(self, prop)[2] != 1 for prop in axis_props)
        if has_non_unit_stride:
            raise UnsupportedModelError(f"Unsupported slice values for {self.full_name_msg}")

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize via ``SliceLayer.to_pb`` and tag the node as ``PROTO_NETWORK_SLICE``."""
        pb_node = super().to_pb(pb_wrapper, is_multi_scope)
        pb_node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_SLICE
        return pb_node

    def get_equalization_handler_type(self, predecessor=None):
        """Report this layer as transparent (non-source) for equalization."""
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Report this layer as unsupported (non-source) for params sorting."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Classify the layer for dead-channel removal.

        A features stride greater than 1 makes the layer unsupported;
        otherwise it is transparent. Either way it is not a source.
        """
        # TODO: this rule may need refinement. NOTE(review): _validate_slice_params
        # only admits unit strides, so the unsupported branch can only trigger if
        # this runs before validation. It also reads the private _features_slice,
        # whereas validation uses the features_slice property — confirm the two agree.
        strided_features = self._features_slice[2] > 1
        handler = LayerHandlerType.unsupported if strided_features else LayerHandlerType.transparent
        return EquivClassification(handler, is_source=False)

    def ibc_supported(self):
        """IBC is never supported for this layer."""
        return LayerSupportStatus.unsupported
