from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer import Layer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class SpatialSplitterLayer(Layer):
    """Layer that splits its input into several outputs along a single axis (``axis``).

    Each output keeps the input's dimensions except along ``axis``, where it takes the
    size recorded in its own entry of ``output_shapes``. When validation is requested,
    the per-output sizes along ``axis`` must add up to the input size along ``axis``.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False
    _IS_REAL_LAYER = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.spatial_splitter
        # Default to the width dimension (presumably a (batch, H, W, C) layout - TODO confirm).
        self._axis = 2
        self._split_sizes = None
        self._split_indices = []

    @property
    def split_sizes(self):
        """Sizes of the split parts as stored on the layer (None until set)."""
        return self._split_sizes

    @split_sizes.setter
    def split_sizes(self, new_split_dims):
        self._split_sizes = new_split_dims

    @property
    def axis(self):
        """Index of the dimension the input is split along."""
        return self._axis

    @axis.setter
    def axis(self, new_axis):
        self._axis = new_axis

    @property
    def split_indices(self):
        """Split indices stored on the layer (empty list by default)."""
        return self._split_indices

    @split_indices.setter
    def split_indices(self, split_indices):
        self._split_indices = split_indices

    @classmethod
    def create(cls, original_name, input_vertex_order, split_sizes, axis, output_shapes=None):
        """Create a splitter layer through the base-class factory and set its split parameters.

        Args:
            original_name: Name forwarded to the base ``create``.
            input_vertex_order: Input ordering forwarded to the base ``create``.
            split_sizes: Sizes of the split parts; stored as-is.
            axis: Index of the dimension to split along.
            output_shapes: Optional output shapes forwarded to the base ``create``.

        Returns:
            The newly created layer.
        """
        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer._axis = axis
        layer._split_sizes = split_sizes
        return layer

    def update_output_shapes(self, validate_shapes=True, **kwargs):
        """Recompute ``output_shapes`` from the current input shape.

        Args:
            validate_shapes: When True, first check that the current output sizes along
                ``axis`` sum to the input size along ``axis``.
            **kwargs: Ignored.

        Raises:
            UnsupportedModelError: If validation is requested and the sizes do not match.
        """
        if validate_shapes and not self._validate_output_shapes():
            raise UnsupportedModelError(
                f"Unexpected split shapes at {self.full_name_msg}, "
                f"output_shapes={self.output_shapes}, input_shapes={self.input_shapes}",
            )
        # Overridden because len(output_shapes) > 1 while output_copies == 1.
        self.output_shapes = self._calc_output_shape()

    def _validate_output_shapes(self):
        """Return True if the output sizes along ``axis`` sum to the first input's size along ``axis``."""
        split_total = sum(shape[self.axis] for shape in self.output_shapes)
        return split_total == self.input_shapes[0][self.axis]

    def _calc_output_shape(self):
        """Build one output shape per current entry of ``output_shapes``.

        Each new shape is ``[-1, in[1], in[2], in[3]]`` built from the input shape
        (dim 0 is presumably the batch dimension), with its ``axis`` entry replaced by
        the size of the corresponding current output shape along ``axis``.
        """
        output_shapes = []
        for curr_shape in self.output_shapes:
            output_shape = [-1, self.input_shape[1], self.input_shape[2], self.input_shape[3]]
            output_shape[self.axis] = curr_shape[self.axis]
            output_shapes.append(output_shape)
        return output_shapes

    def _get_output_shape(self, validate=False, layer_name=None, layer_index=None):
        """Return the output shape that feeds a particular successor.

        When ``_output_indices`` is non-empty, the shape is selected by the position of
        ``layer_index`` in it; otherwise by the position of ``layer_name`` in ``outputs``.

        Args:
            validate: Not used by this implementation.
            layer_name: Name of the successor layer; always required.
            layer_index: Successor index; required when ``_output_indices`` is non-empty.

        Raises:
            UnsupportedModelError: If the successor name is missing, or if the successor
                index is missing while output indices are in use.
        """
        if layer_name is None:
            raise UnsupportedModelError(f"{self.full_name_msg} successor name is missing, output shape is ambiguous")
        if len(self._output_indices) > 0:
            if layer_index is None:
                raise UnsupportedModelError(
                    f"{self.full_name_msg} successor index is missing, output shape is ambiguous",
                )
            return self._output_shapes[self._output_indices.index(layer_index)]
        return self._output_shapes[self.outputs.index(layer_name)]

    def sort_outputs(self):
        """Return a comparator ordering layers by the position of their name in ``outputs``.

        The comparator follows the cmp protocol (negative / zero / positive), so it can be
        wrapped with ``functools.cmp_to_key``.
        """

        def _compare(layer1, layer2):
            index1 = self.outputs.index(layer1.name)
            index2 = self.outputs.index(layer2.name)
            # Yield 0 for equal positions so that cmp(a, a) == 0, as the cmp protocol requires.
            return (index1 > index2) - (index1 < index2)

        return _compare

    def get_equalization_handler_type(self, predecessor=None):
        """Classify this layer as unsupported (non-source) for equalization."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Classify this layer as unsupported (non-source) for params sorting."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Classify this layer as unsupported (non-source) for dead-channel removal."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        """Report IBC support status for this layer (always unsupported)."""
        return LayerSupportStatus.unsupported
