import copy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    DEFAULT_RANK2_SLICE,
    LayerHandlerType,
    LayerSupportStatus,
)
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class SliceLayer(LayerWithParams):
    """Layer that slices its input along the height, width and features axes.

    Each slice is stored as ``[start, stop, stride]``; slices that do not already have three
    items are padded with a stride of 1 when given to ``create``/``set_slices_dims``/``from_hn``.
    While a slice is unset, its property falls back to the full extent of the matching
    ``input_shape`` axis (height -> ``input_shape[1]``, width -> ``input_shape[2]``,
    features -> ``input_shape[-1]``). The features size of the output shape is multiplied by
    ``groups``; presumably the features slice is applied per group, since validation bounds it
    by ``input_shape[-1] / groups`` (NOTE(review): confirm against the kernel).
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False
    _IS_REAL_LAYER = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.base_slice
        # ``None`` means "not set": the properties below then return the full-axis slice.
        self._height_slice = None
        self._width_slice = None
        self._features_slice = None
        self._groups = 1

    @classmethod
    def create(
        cls,
        original_name,
        input_vertex_order,
        height_slice,
        width_slice,
        features_slice,
        output_shapes=None,
        groups=1,
    ):
        """Create a slice layer with the given slices and groups.

        Args:
            original_name: Forwarded to the base ``create``.
            input_vertex_order: Forwarded to the base ``create``.
            height_slice: ``[start, stop, stride]`` for the height axis, or ``[start, stop]``
                (a stride of 1 is appended).
            width_slice: Same format, for the width axis.
            features_slice: Same format, for the features axis.
            output_shapes: Forwarded to the base ``create``.
            groups: Multiplier applied to the features size of the output shape.

        Returns:
            The new layer.
        """
        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer.set_slices_dims(height_slice, width_slice, features_slice)
        layer.groups = groups
        return layer

    @staticmethod
    def _normalize_slice(slice_params):
        """Return a new list with a stride of 1 appended unless ``slice_params`` has 3 items.

        A fresh list is always returned so the layer never aliases (or extends) the caller's
        sequence, and tuples are accepted as well as lists.
        """
        if len(slice_params) == 3:
            return list(slice_params)
        return [*slice_params, 1]

    def set_slices_dims(self, height_slice, width_slice, features_slice):
        """Set the height, width and features slices, padding each with stride 1 if needed."""
        self.height_slice = self._normalize_slice(height_slice)
        self.width_slice = self._normalize_slice(width_slice)
        self.features_slice = self._normalize_slice(features_slice)

    @property
    def height_slice(self):
        """Height slice; defaults to ``[0, input_shape[1], 1]`` while unset."""
        if self._height_slice is None:
            return [0, self.input_shape[1], 1]
        return self._height_slice

    @height_slice.setter
    def height_slice(self, height_slice):
        self._height_slice = height_slice

    @property
    def width_slice(self):
        """Width slice; defaults to ``[0, input_shape[2], 1]`` while unset."""
        if self._width_slice is None:
            return [0, self.input_shape[2], 1]
        return self._width_slice

    @width_slice.setter
    def width_slice(self, width_slice):
        self._width_slice = width_slice

    @property
    def features_slice(self):
        """Features slice; defaults to ``[0, input_shape[-1], 1]`` while unset."""
        if self._features_slice is None:
            return [0, self.input_shape[-1], 1]
        return self._features_slice

    @features_slice.setter
    def features_slice(self, features_slice):
        self._features_slice = features_slice

    @property
    def groups(self):
        """Number of groups; multiplies the features size of the output shape."""
        return self._groups

    @groups.setter
    def groups(self, groups):
        self._groups = groups

    def set_input_shapes(self, input_shapes, validate=True):
        """Set the input shapes, then resolve zero/negative slice bounds against them."""
        super().set_input_shapes(input_shapes, validate)
        self._update_zero_negative_slices()

    @staticmethod
    def _is_degenerate_slice(slice_params):
        """Return True for an empty or reversed range (``start >= stop``) or a non-positive stride.

        Such slices would make ``_calc_output_shape`` produce a zero/negative size or divide
        by zero, so they are rejected during validation.
        """
        return slice_params[0] >= slice_params[1] or slice_params[2] <= 0

    def _validate_slice_params(self):
        """Raise ``UnsupportedModelError`` if the slices cannot be applied to ``input_shape``.

        Rejected when any of the following holds:
            * ``input_shape[-1]`` is not divisible by ``groups``;
            * the features slice starts below 0 or stops beyond ``input_shape[-1] / groups``;
            * the features slice (or, for rank-4 inputs, the height/width slice) is empty,
              reversed, or has a non-positive stride;
            * for rank-2 inputs, the height/width slices differ from ``DEFAULT_RANK2_SLICE``;
            * for rank-4 inputs, the height/width slices fall outside the input bounds.
        """
        rank2 = len(self.input_shape) == 2
        invalid_groups = self.input_shape[-1] % self.groups != 0
        groups_size = self.input_shape[-1] / self.groups
        invalid_features = (
            self.features_slice[0] < 0
            or self.features_slice[1] > groups_size
            or self._is_degenerate_slice(self.features_slice)
        )
        invalid_rank2 = rank2 and (self.height_slice != DEFAULT_RANK2_SLICE or self.width_slice != DEFAULT_RANK2_SLICE)
        invalid_rank4 = not rank2 and (
            self.height_slice[0] < 0
            or self.width_slice[0] < 0
            or self.height_slice[1] > self.input_shape[1]
            or self.width_slice[1] > self.input_shape[2]
            or self._is_degenerate_slice(self.height_slice)
            or self._is_degenerate_slice(self.width_slice)
        )
        if invalid_groups or invalid_features or invalid_rank2 or invalid_rank4:
            raise UnsupportedModelError(f"Unsupported slice values for {self.full_name_msg}")

    @staticmethod
    def _resolve_slice(slice_params, dim_size):
        """Return ``[start, stop, stride]`` with bounds made absolute for an axis of ``dim_size``.

        ``stop`` is first clamped to ``dim_size``. Then a negative ``start`` and a non-positive
        ``stop`` are counted back from ``dim_size``, so ``stop == 0`` means "to the end of the
        axis". Values that are still out of range are left for ``_validate_slice_params``.
        """
        start = slice_params[0]
        stop = min(slice_params[1], dim_size)
        resolved_start = dim_size + start if start < 0 else start
        resolved_stop = dim_size + stop if stop <= 0 else stop
        return [resolved_start, resolved_stop, slice_params[2]]

    def _update_zero_negative_slices(self):
        """Rewrite zero/negative slice bounds as absolute indices for the current input shape.

        Done to match the kernel implementation. For rank-4 inputs, height and width are
        resolved against ``input_shape[1]`` and ``input_shape[2]``. For rank-2 inputs, both are
        replaced by ``DEFAULT_RANK2_SLICE``. The features slice is always resolved against
        ``input_shape[-1]``.
        """
        rank = len(self.input_shape)
        if rank == 4:
            if self.height_slice:
                self.height_slice = self._resolve_slice(self.height_slice, self.input_shape[1])
            if self.width_slice:
                self.width_slice = self._resolve_slice(self.width_slice, self.input_shape[2])
        elif rank == 2:
            self.height_slice = DEFAULT_RANK2_SLICE
            self.width_slice = DEFAULT_RANK2_SLICE
        if self.features_slice:
            self.features_slice = self._resolve_slice(self.features_slice, self.input_shape[-1])

    @staticmethod
    def _sliced_size(slice_params):
        """Number of elements selected by ``[start, stop, stride]`` (for ``stop > start``, ``stride > 0``)."""
        return int((slice_params[1] - slice_params[0] - 1) // slice_params[2] + 1)

    def _calc_output_shape(self):
        """Validate the slices and return the output shape.

        The first (presumably batch) dimension is ``-1``. The last dimension is the sliced
        features size multiplied by ``groups``. Rank-4 inputs also get the sliced height and width.
        """
        self._validate_slice_params()
        features = self._sliced_size(self.features_slice) * self.groups
        if len(self.input_shape) == 4:
            return [-1, self._sliced_size(self.height_slice), self._sliced_size(self.width_slice), features]
        return [-1, features]

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build the layer from its protobuf node, reading start/end/stride of each slice and groups."""
        layer = super().from_pb(pb, pb_wrapper)
        layer._height_slice = [pb.height_slice.start, pb.height_slice.end, pb.height_slice.stride]
        layer._width_slice = [pb.width_slice.start, pb.width_slice.end, pb.width_slice.stride]
        layer._features_slice = [pb.features_slice.start, pb.features_slice.end, pb.features_slice.stride]
        layer.groups = pb.groups

        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build the layer from ``old_layer``.

        The slices and groups are copied unless ``old_layer.op`` is ``rnn`` or ``lstm``. In that
        case they stay unset.
        """
        layer = super().from_layer(old_layer)
        if old_layer.op not in [LayerType.rnn, LayerType.lstm]:
            layer.height_slice = old_layer.height_slice
            layer.width_slice = old_layer.width_slice
            layer.features_slice = old_layer.features_slice
            layer.groups = old_layer.groups
        return layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize the layer into a ``PROTO_NETWORK_BASE_SLICE`` protobuf node."""
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_BASE_SLICE
        node.height_slice.start, node.height_slice.end, node.height_slice.stride = (
            self.height_slice[0],
            self.height_slice[1],
            self.height_slice[2],
        )
        node.width_slice.start, node.width_slice.end, node.width_slice.stride = (
            self.width_slice[0],
            self.width_slice[1],
            self.width_slice[2],
        )
        node.features_slice.start, node.features_slice.end, node.features_slice.stride = (
            self.features_slice[0],
            self.features_slice[1],
            self.features_slice[2],
        )
        node.groups = self.groups

        return node

    @classmethod
    def from_hn(cls, hn):
        """Build the layer from its HN dict, padding slices with a stride of 1 where needed.

        The slice lists in ``hn["params"]`` are copied rather than extended in place, so the
        caller's HN dict is left unmodified. ``groups`` defaults to 1 when absent.
        """
        layer = super().from_hn(hn)
        params = hn["params"]
        layer.groups = params.get("groups", 1)
        layer._height_slice = cls._normalize_slice(params["height_slice"])
        layer._width_slice = cls._normalize_slice(params["width_slice"])
        layer._features_slice = cls._normalize_slice(params["features_slice"])

        return layer

    def to_hn(self, should_get_default_params=False):
        """Serialize the layer into an HN dict.

        The stored slices are written as int lists; the input-shape defaults are not used, so
        the slices must have been set.
        """
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["height_slice"] = [int(x) for x in self._height_slice]
        result["params"]["width_slice"] = [int(x) for x in self._width_slice]
        result["params"]["features_slice"] = [int(x) for x in self._features_slice]
        result["params"]["groups"] = self.groups

        return result

    def get_equalization_handler_type(self, predecessor=None):
        """Classify this layer as ``LayerHandlerType.unexpected`` (non-source) for equalization."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Classify this layer as ``LayerHandlerType.unexpected`` (non-source) for params sorting."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Classify this layer as ``LayerHandlerType.unexpected`` (non-source) for dead-channel removal."""
        return EquivClassification(LayerHandlerType.unexpected, is_source=False)

    def ibc_supported(self):
        """Report IBC support status as ``LayerSupportStatus.unexpected``."""
        return LayerSupportStatus.unexpected
