import copy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer import Layer
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class FeatureSplitterLayer(Layer):
    """Split a single input along its last (feature) axis into several outputs.

    ``split_sizes`` holds the size of each piece along the feature axis.
    ``split_indices`` optionally maps each output position to the piece it
    carries. Indices may repeat, in which case several outputs carry the same
    piece. ``groups`` is stored and serialized to HN/protobuf; its semantics
    are not defined in this class.
    """

    # No weights (native or quantized) are required by this layer.
    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False
    _IS_REAL_LAYER = True

    def __init__(self):
        super().__init__()
        self._op = LayerType.feature_splitter
        # Piece sizes along the feature axis; None until create() or the setter assigns them.
        self._split_sizes = None
        # Output-position -> piece mapping. Empty means "one output per split_sizes entry".
        self._split_indices = []
        self._groups = 1

    @property
    def split_sizes(self):
        """Piece sizes along the feature axis, or None if not set yet."""
        return self._split_sizes

    @split_sizes.setter
    def split_sizes(self, new_split_dims):
        self._split_sizes = new_split_dims

    @property
    def split_indices(self):
        """Output-position -> piece mapping (a list that from_layer extends in place)."""
        return self._split_indices

    @split_indices.setter
    def split_indices(self, split_indices):
        self._split_indices = split_indices

    @property
    def output_features(self):
        """Total feature count across all outputs.

        This is the sum of each output shape's last dimension. The last axis is
        used so that both rank-2 ``[-1, C]`` and rank-4 ``[-1, H, W, C]``
        output shapes are supported.
        """
        return sum(shape[-1] for shape in self._output_shapes)

    @property
    def groups(self):
        """Grouping value written to the HN ``params`` and to the protobuf node."""
        return self._groups

    @groups.setter
    def groups(self, groups):
        self._groups = groups

    @classmethod
    def create(cls, original_name, input_vertex_order, split_sizes, output_shapes=None, groups=1):
        """Build a splitter through the base ``create`` and attach ``split_sizes`` and ``groups``."""
        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer._split_sizes = split_sizes
        layer.groups = groups
        return layer

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Deserialize from a protobuf node.

        Raises:
            UnsupportedModelError: if the deserialized layer has no output shapes.
        """
        layer = super().from_pb(pb, pb_wrapper)
        if not layer.output_shapes:
            raise UnsupportedModelError(f"{layer.full_name_msg} requires output shapes")
        layer.groups = pb.groups

        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Build a splitter from another layer.

        ``split_indices`` and ``groups`` are copied unless ``old_layer`` is an
        LSTM. Presumably an LSTM source does not carry these attributes;
        TODO confirm.
        """
        layer = super().from_layer(old_layer)
        if old_layer.op != LayerType.lstm:
            layer.split_indices.extend(old_layer.split_indices)
            layer.groups = old_layer.groups
        return layer

    @classmethod
    def from_hn(cls, hn):
        """Deserialize from an HN dict. ``groups`` defaults to 1 when absent.

        Raises:
            UnsupportedModelError: if the deserialized layer has no output shapes.
        """
        layer = super().from_hn(hn)
        layer.groups = hn.get("params", {}).get("groups", 1)
        if not layer.output_shapes:
            raise UnsupportedModelError(f"{layer.full_name_msg} requires output shapes")

        return layer

    def to_hn(self, should_get_default_params=False):
        """Serialize to an HN dict.

        The ``params`` entry is replaced by ``{"groups": ...}``.
        """
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"] = {"groups": self.groups}
        return result

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node typed as a feature splitter.

        The node also carries the split indices and ``groups``.
        """
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_FEATURE_SPLITTER
        for split_index in self._split_indices:
            node.split_indices.append(split_index)
        node.groups = self.groups
        return node

    def set_input_shapes(self, input_shapes, validate=True):
        """Set the input shapes and, if ``split_sizes`` is known, derive the output shapes.

        Each output keeps the input's non-feature dimensions and takes its
        piece size as the last dimension. If ``split_indices`` has at least as
        many entries as there are pieces, there is one output per split index.
        Otherwise there is one output per ``split_sizes`` entry.
        """
        super().set_input_shapes(input_shapes, validate)
        if not self._split_sizes:
            return
        inp = self.input_shapes[0]
        if self._split_indices and len(self._split_indices) >= len(self._split_sizes):
            feature_sizes = [self._split_sizes[index] for index in self._split_indices]
        else:
            feature_sizes = self._split_sizes
        if len(inp) == 2:
            self.output_shapes = [[-1, size] for size in feature_sizes]
        else:
            self.output_shapes = [[-1, inp[1], inp[2], size] for size in feature_sizes]

    def update_output_shapes(self, validate_shapes=True, **kwargs):
        """Optionally validate the current output shapes, then normalize them.

        Raises:
            UnsupportedModelError: if ``validate_shapes`` is set and the output
                feature dimensions do not add up to the input's.
        """
        if validate_shapes and not self._validate_output_shapes():
            raise UnsupportedModelError(
                f"Unexpected split shapes at {self.full_name_msg}, "
                f"output_shapes={self.output_shapes}, input_shapes={self.input_shapes}",
            )
        # Overridden because len(output_shapes) > 1 while output_copies == 1.
        self.output_shapes = self._calc_output_shape()

    def _validate_output_shapes(self):
        """Check that the output feature dimensions add up to the input feature dimension.

        Returns True when they match. With ``split_indices``, each distinct
        index is counted once: the last output mapped to it wins in the dict.
        This way, outputs that repeat a piece are not double counted.
        """
        if self._split_indices:
            shape_per_index = {split_index: self.output_shapes[i] for i, split_index in enumerate(self._split_indices)}
            pieces = shape_per_index.values()
        else:
            pieces = self.output_shapes
        return sum(shape[-1] for shape in pieces) == self.input_shapes[0][-1]

    def _calc_output_shape(self):
        """Rebuild each output shape with a -1 batch dimension.

        Rank-2 shapes keep only their feature dimension. Other shapes take the
        input's dimensions 1 and 2 plus their own feature dimension.
        """
        output_shapes = []
        for output_shape in self.output_shapes:
            if len(output_shape) == 2:
                output_shapes.append([-1, output_shape[-1]])
            else:
                output_shapes.append([-1, self.input_shape[1], self.input_shape[2], output_shape[-1]])
        return output_shapes

    def _get_output_shape(self, validate=False, layer_name=None, layer_index=None):
        """Return the output shape that feeds the successor ``layer_name``.

        ``validate`` is not used by this implementation.

        Raises:
            UnsupportedModelError: if ``layer_name`` is missing, or if output
                indices are in use and ``layer_index`` is missing.
        """
        if layer_name is None:
            raise UnsupportedModelError(f"{self.full_name_msg} successor name is missing, output shape is ambiguous")
        if self.split_indices:
            # NOTE(review): this indexes output_shapes by the piece index, not
            # by the output position. set_input_shapes builds output_shapes
            # per output position, so the two orderings agree only for some
            # split_indices layouts. Confirm this is intended.
            return self.output_shapes[self.split_indices[self.outputs.index(layer_name)]]
        elif len(self._output_indices) > 0:
            if layer_index is None:
                raise UnsupportedModelError(
                    f"{self.full_name_msg} successor index is missing, output shape is ambiguous",
                )
            return self._output_shapes[self._output_indices.index(layer_index)]
        return self._output_shapes[self.outputs.index(layer_name)]

    def sort_outputs(self):
        """Return a cmp-style comparator that orders layers by their position in ``self.outputs``.

        The comparator returns -1, 0 or 1. It returns 0 when both layers have
        the same position, for example a layer compared with itself, as the
        cmp contract requires. Presumably it is consumed via
        ``functools.cmp_to_key``; verify against callers.
        """

        def compare(layer1, layer2):
            pos1 = self.outputs.index(layer1.name)
            pos2 = self.outputs.index(layer2.name)
            return (pos1 > pos2) - (pos1 < pos2)

        return compare

    def get_equalization_handler_type(self, predecessor=None):
        """Equalization treats this layer as transparent and not as a source."""
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Params sorting is not supported through this layer."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Dead-channels removal is not supported through this layer."""
        # TODO: https://hailotech.atlassian.net/browse/SDK-37267
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        """IBC is reported as unsupported for this layer."""
        return LayerSupportStatus.unsupported
