import copy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import HnStage, LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class FeatureShuffleLayer(LayerWithParams):
    """HN layer describing a feature shuffle (``LayerType.feature_shuffle``).

    The features (last axis of every input shape) are divided into ``groups`` groups, so the
    feature count must be divisible by ``groups``. An optional ``groups_slice`` of the form
    ``[start, end, stride]`` is serialized alongside the group count. When the layer was parsed
    from a reshape -> transpose -> reshape pattern, the two reshape target shapes are kept and
    validated at the ``HnStage.PRE_FUSED`` stage.
    """

    # NOTE(review): flags presumably consumed by the base class / weight loaders; this layer
    # does not require native or quantized weights.
    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False

    def __init__(self):
        super().__init__()
        self._op = LayerType.feature_shuffle
        self._groups = None
        # Either empty or a [start, end, stride] triple (see to_pb).
        self._groups_slice = []
        # Target shapes of the original reshape ops this layer was parsed from (may stay None).
        self._first_reshape = None
        self._second_reshape = None

    @classmethod
    def create(
        cls,
        original_name,
        input_vertex_order,
        groups,
        groups_slice=None,
        last_reshape_name=None,
        output_shapes=None,
        first_reshape=None,
        second_reshape=None,
    ):
        """Build a feature-shuffle layer.

        Args:
            original_name: Name of the original op; forwarded to the base ``create``.
            input_vertex_order: Forwarded to the base ``create``.
            groups: Number of shuffle groups.
            groups_slice: Optional ``[start, end, stride]`` triple; defaults to an empty list.
            last_reshape_name: If given, registered as an additional original name of this layer.
            output_shapes: Forwarded to the base ``create``.
            first_reshape: Target shape of the first parsed reshape.
            second_reshape: Target shape of the second parsed reshape.
                Both reshapes are stored only when both are provided.

        Returns:
            The new ``FeatureShuffleLayer``.
        """
        layer = super().create(original_name, input_vertex_order, output_shapes)
        layer.groups = groups
        layer.groups_slice = groups_slice if groups_slice is not None else []
        if last_reshape_name:
            layer.add_original_name(last_reshape_name)
        if first_reshape and second_reshape:
            layer._first_reshape = first_reshape
            layer._second_reshape = second_reshape

        return layer

    @property
    def groups(self):
        """Number of shuffle groups (``None`` until set)."""
        return self._groups

    @groups.setter
    def groups(self, groups):
        self._groups = groups

    @property
    def groups_slice(self):
        """Optional ``[start, end, stride]`` triple; empty when unused."""
        return self._groups_slice

    @groups_slice.setter
    def groups_slice(self, groups_slice):
        self._groups_slice = groups_slice

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize the layer into a protobuf node, including ``groups`` and ``groups_slice``."""
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_FEATURE_SHUFFLE
        node.groups = self.groups
        if self.groups_slice:
            node.groups_slice.start, node.groups_slice.end, node.groups_slice.stride = self.groups_slice
        return node

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Deserialize a layer from a protobuf node (inverse of ``to_pb``)."""
        layer = super().from_pb(pb, pb_wrapper)
        layer.groups = pb.groups
        # groups_slice is a sub-message (to_pb writes .start/.end/.stride). Test presence with
        # HasField and convert to the same [start, end, stride] list used by create/from_hn.
        # Storing the message itself would leak a protobuf object into to_hn.
        if pb.HasField("groups_slice"):
            slice_pb = pb.groups_slice
            layer.groups_slice = [slice_pb.start, slice_pb.end, slice_pb.stride]
        return layer

    @classmethod
    def from_layer(cls, old_layer):
        """Create a new layer that copies ``groups`` and ``groups_slice`` from ``old_layer``."""
        layer = super().from_layer(old_layer)
        layer.groups = old_layer.groups
        # Shallow copy so the two layers do not share (and mutate) the same list.
        layer.groups_slice = copy.copy(old_layer.groups_slice)
        return layer

    @classmethod
    def from_hn(cls, hn):
        """Deserialize a layer from its HN dict; ``groups``/``groups_slice`` are read from ``hn["params"]``."""
        layer = super().from_hn(hn)
        if "params" in hn:
            layer.groups = hn["params"].get("groups", None)
            layer.groups_slice = hn["params"].get("groups_slice", [])
        return layer

    def to_hn(self, should_get_default_params=False):
        """Serialize the layer to an HN dict, adding ``groups`` and ``groups_slice`` under ``"params"``."""
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        # from_hn treats "params" as optional, so do not assume the parent always emitted it.
        params = result.setdefault("params", {})
        params["groups"] = self.groups
        params["groups_slice"] = self.groups_slice
        return result

    def get_reshape_shapes(self):
        """Return ``(first_reshape, second_reshape)`` as stored by ``create`` (``None`` when absent)."""
        return self._first_reshape, self._second_reshape

    def validate_parsed_reshaped_features(self):
        """Check that the parsed reshape shapes match a feature shuffle of the current input.

        The expected shapes (ignoring the leading dimension) are ``[groups, c // groups, h, w]``
        for the first reshape and ``[-1, h, w]`` for the second, where ``h, w, c`` come from
        ``self.input_shape[1:]``. Does nothing unless both reshape shapes were recorded.

        Raises:
            UnsupportedModelError: If either recorded reshape differs from its expected shape.
        """
        first_reshape, second_reshape = self.get_reshape_shapes()
        if not (first_reshape and second_reshape):
            return
        h, w, c = self.input_shape[1:]
        expected_first_shape = [self.groups, c // self.groups, h, w]
        expected_second_shape = [-1, h, w]
        # Normalize to lists: a tuple slice never compares equal to a list, which would reject
        # an otherwise valid pattern.
        first_matches = list(first_reshape[1:]) == expected_first_shape
        second_matches = list(second_reshape[1:]) == expected_second_shape
        if not (first_matches and second_matches):
            raise UnsupportedModelError(
                f"{self.full_name_msg} implemented by an unexpected reshape and transpose "
                f"combination. Expected first reshape was {expected_first_shape} and got "
                f"{first_reshape}, expected second reshape was {expected_second_shape} and "
                f"got {second_reshape}",
            )

    def _validate_input_shapes(self):
        """Return whether every input's feature count (last axis) is divisible by ``groups``.

        Returns ``True`` while ``groups`` is still unset, because there is nothing to check yet.
        Returns ``False`` for a non-positive group count, which cannot describe a shuffle and
        would otherwise raise ``ZeroDivisionError``.
        """
        groups = self._groups
        if groups is None:
            return True
        if groups <= 0:
            return False
        # Check the *input* shapes: this runs from set_input_shapes, where output shapes may be
        # stale or not yet computed.
        return all(shape[-1] % groups == 0 for shape in self.input_shapes)

    def update_output_shapes(self, **kwargs):
        """Validate parsed reshapes at the PRE_FUSED stage, then defer to the base implementation.

        Keyword Args:
            hn_stage: Current ``HnStage``; when omitted, the PRE_FUSED validation is skipped.
        """
        hn_stage = kwargs.get("hn_stage")
        if hn_stage == HnStage.PRE_FUSED:
            self.validate_parsed_reshaped_features()
        super().update_output_shapes()

    def set_input_shapes(self, input_shapes, validate=True):
        """Set the input shapes and, if ``validate`` is true, check divisibility by ``groups``.

        Raises:
            UnsupportedModelError: If validation is requested and an input's feature count is
                not divisible by the number of shuffle groups.
        """
        super().set_input_shapes(input_shapes, validate=validate)
        if validate and not self._validate_input_shapes():
            raise UnsupportedModelError(
                f"{self.full_name_msg} number of input features ({self.input_shape[-1]}) is not"
                f" divisible by the number of shuffle groups ({self._groups})",
            )

    def get_equalization_handler_type(self, predecessor=None):
        """Classify this layer as a transparent, non-source layer for equalization."""
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Mark this layer as unsupported (non-source) for params sorting."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Mark this layer as unsupported (non-source) for dead-channel removal."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        """Report that IBC is not supported for this layer."""
        return LayerSupportStatus.unsupported
