import copy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    LayerHandlerType,
    LayerSupportStatus,
)
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import DefuseType, LayerType, PrecisionSplitMode
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class PrecisionSplitterLayer(LayerWithParams):
    """HN layer of type ``LayerType.precision_splitter``.

    The layer carries a ``PrecisionSplitMode`` that determines how many
    outputs it has and how their shape is derived from the input shape:

    * ``NORMAL``: exactly two outputs.
    * any other mode: ``output_copies`` outputs.

    All outputs share one shape: the input shape, with axis 3 replaced by
    ``defuse_features`` for normal-defused layers, and with axis 2 doubled in
    ``PIXELS`` mode.
    """

    # Flags presumably consumed by base-class / framework code (not read in
    # this class): no native or quantized weights, not a "real" layer.
    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False
    _IS_REAL_LAYER = False

    def __init__(self):
        super().__init__()
        self._op = LayerType.precision_splitter
        # Default mode; overwritten by from_pb / from_hn when deserializing.
        self._precision_split_mode = PrecisionSplitMode.NORMAL

    @property
    def precision_split_mode(self):
        """The layer's ``PrecisionSplitMode``."""
        return self._precision_split_mode

    @precision_split_mode.setter
    def precision_split_mode(self, precision_split_mode):
        self._precision_split_mode = precision_split_mode

    @classmethod
    def create(cls, original_name, input_vertex_order, output_shapes=None):
        """Create a layer via the base-class factory (thin pass-through)."""
        return super().create(original_name, input_vertex_order, output_shapes)

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build the layer from its protobuf node.

        Args:
            pb: Protobuf node describing the layer.
            pb_wrapper: Helper exposing ``PRECISION_SPLIT_MODE_PB_TO_TYPE``.

        Returns:
            The deserialized ``PrecisionSplitterLayer``.

        Raises:
            UnsupportedModelError: If the deserialized layer has no output shapes.
        """
        layer = super().from_pb(pb, pb_wrapper)
        # Check the instance, not the class: ``cls.output_shapes`` resolves to
        # the class attribute (presumably the property descriptor), which is
        # always truthy and therefore never reports missing shapes.
        if not layer.output_shapes:
            raise UnsupportedModelError("layer Precision Splitter requires output shapes")
        layer.precision_split_mode = pb_wrapper.PRECISION_SPLIT_MODE_PB_TO_TYPE[pb.precision_split_mode]
        return layer

    @classmethod
    def from_hn(cls, hn):
        """Build the layer from its HN (dict) description.

        Note: inserts an empty ``"params"`` dict into *hn* when it is missing,
        i.e. the argument is mutated.

        Raises:
            UnsupportedModelError: If the deserialized layer has no output shapes.
        """
        hn.setdefault("params", {})
        layer = super().from_hn(hn)
        # Check the instance; see from_pb for why ``cls.output_shapes`` is wrong.
        if not layer.output_shapes:
            raise UnsupportedModelError(f"{layer.full_name_msg} requires output shapes")

        precision_split_mode = hn.get("params", {}).get("precision_split_mode", PrecisionSplitMode.NORMAL)
        layer.precision_split_mode = PrecisionSplitMode(precision_split_mode)
        return layer

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Serialize to a protobuf node, tagging type, split mode and 8-bit weights precision."""
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_PRECISION_SPLITTER
        node.precision_split_mode = pb_wrapper.PRECISION_SPLIT_MODE_TYPE_TO_PB[self.precision_split_mode]
        node.quantization_params.precision_mode.weights_precision_mode = (
            pb_wrapper.integrated_hw_graph_base_pb2.PROTO_WEIGHTS_PRECISION_MODE_8BIT
        )
        return node

    def to_hn(self, should_get_default_params=False):
        """Serialize to an HN dict, adding ``params["precision_split_mode"]`` (enum value)."""
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["precision_split_mode"] = self.precision_split_mode.value
        return result

    def update_output_shapes(self, validate_shapes=True, **kwargs):
        """Recompute ``output_shapes`` from the input shape.

        Raises:
            UnsupportedModelError: If *validate_shapes* is true and the current
                output shapes disagree with what the input shape implies.
        """
        if validate_shapes and not self._validate_output_shapes():
            raise UnsupportedModelError(
                f"Unexpected split shapes at {self.full_name_msg}, "
                f"output_shapes={self.output_shapes}, input_shapes={self.input_shapes}",
            )
        # Overridden because len(output_shapes) > 1 while output_copies == 1
        # (NORMAL mode always has two outputs).
        self.output_shapes = self._calc_output_shape()

    def _shape_for_input(self, input_shape):
        """Return the per-output shape implied by *input_shape* (a fresh copy)."""
        shape = copy.deepcopy(input_shape)
        if "defuse_features" in self.defuse_params and self.defuse_type == DefuseType.normal:
            shape[3] = self.defuse_features
        if self.precision_split_mode == PrecisionSplitMode.PIXELS:
            # Axis 2 is doubled in PIXELS mode (presumably width in an NHWC
            # layout -- confirm).
            shape[2] *= 2
        return shape

    def _validate_output_shapes(self):
        """Return True iff the current output shapes match the first input shape."""
        if self.precision_split_mode == PrecisionSplitMode.NORMAL and len(self.output_shapes) != 2:
            return False
        expected_output_shape = self._shape_for_input(self.input_shapes[0])
        return not any(shape != expected_output_shape for shape in self.output_shapes)

    def _calc_output_shape(self):
        """Return the list of output shapes implied by the input shape.

        Every entry is an independent copy, so mutating one output shape does
        not affect the others.
        """
        output_shape = self._shape_for_input(self.input_shape)
        if self.precision_split_mode == PrecisionSplitMode.NORMAL:
            num_outputs = 2
        else:
            num_outputs = self.output_copies
        return [output_shape[:] for _ in range(num_outputs)]

    def _get_output_shape(self, validate=False, layer_name=None, layer_index=None):
        """Return the output shape feeding the given successor.

        When output indices are recorded, the shape is looked up by
        *layer_index*; otherwise by the position of *layer_name* in
        ``self.outputs``. *validate* is accepted for interface compatibility
        and is unused here.

        Raises:
            UnsupportedModelError: If the successor name (or, when output
                indices exist, the successor index) is missing.
        """
        if layer_name is None:
            raise UnsupportedModelError(f"{self.full_name_msg} successor name is missing, output shape is ambiguous")
        if len(self._output_indices) > 0:
            if layer_index is None:
                raise UnsupportedModelError(
                    f"{self.full_name_msg} successor index is missing, output shape is ambiguous",
                )
            return self._output_shapes[self._output_indices.index(layer_index)]
        return self._output_shapes[self.outputs.index(layer_name)]

    def sort_outputs(self):
        """Return a cmp-style comparator for ordering successor layers.

        In NORMAL mode successors are ordered by their position in
        ``self.outputs``; other modes defer to the base class.
        """
        if self.precision_split_mode != PrecisionSplitMode.NORMAL:
            return super().sort_outputs()

        def _compare_by_output_position(layer1, layer2):
            idx1 = self.outputs.index(layer1.name)
            idx2 = self.outputs.index(layer2.name)
            # Negative / zero / positive, as the cmp contract requires
            # (equal positions must compare as 0).
            return (idx1 > idx2) - (idx1 < idx2)

        return _compare_by_output_position

    def get_equalization_handler_type(self, predecessor=None):
        """Classify the layer as transparent (non-source) for equalization."""
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Classify the layer as transparent (non-source) for params sorting."""
        # TODO: maybe it is not unsupported
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Classify the layer as transparent (non-source) for dead-channel removal."""
        return EquivClassification(LayerHandlerType.transparent, is_source=False)

    def ibc_supported(self):
        """Report IBC as unsupported for this layer."""
        return LayerSupportStatus.unsupported
