import copy

from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerHandlerType, LayerSupportStatus
from hailo_sdk_common.hailo_nn.exceptions import UnsupportedModelError
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.layer_with_params import LayerWithParams
from hailo_sdk_common.hailo_nn.layer_equiv_set import EquivClassification


class NMSLayer(LayerWithParams):
    """HN layer describing an NMS (non-maximum suppression) operation.

    The layer carries five configuration values (score threshold, IoU
    threshold, maximum output size, number of classes and input division
    factor) and knows how to serialize them to / from the HN dictionary
    and the protobuf node representations.
    """

    _REQUIRES_NATIVE_WEIGHTS = False
    _REQUIRES_QUANTIZED_WEIGHTS = False
    _IS_REAL_LAYER = True
    # 5 scalars to describe one bounding box
    BBOX_PARAMETERS = 5
    BBOX_PER_CHUNK = 4
    # Used as the third dimension of the calculated output shape.
    BBOX_WIDTH = 4

    def __init__(self):
        """Create an NMS layer with default parameter values."""
        super().__init__()
        self._op = LayerType.nms
        self._scores_threshold = 0
        self._iou_threshold = 1
        self._max_output_size = 4096
        self._classes = 1
        self._input_division_factor = 1

    def _calc_output_shape(self):
        """Return the output shape of the layer.

        The shape is ``[-1, classes * input_division_factor, BBOX_WIDTH, 1]``.
        The leading ``-1`` presumably stands for the batch dimension -- confirm
        against the base class.
        """
        return [
            -1,
            self.classes * self.input_division_factor,
            type(self).BBOX_WIDTH,
            1,
        ]

    def set_nms_params(self, scores_threshold, iou_threshold, max_output_size, classes, input_division_factor):
        """Store the NMS parameters on the layer and validate them.

        Raises:
            UnsupportedModelError: if the new parameters fail
                :meth:`validate_nms_params`.
        """
        self._scores_threshold = scores_threshold
        self._iou_threshold = iou_threshold
        self._max_output_size = max_output_size
        self._classes = classes
        self._input_division_factor = input_division_factor
        self.validate_nms_params()

    def _classes_per_defuse(self):
        """Return the input height divided by the ``defuse_num_layers`` compilation param.

        When ``defuse_num_layers`` is missing (or explicitly ``None``) the
        divisor falls back to 1, so that validation reports a model error
        instead of failing with a ``TypeError`` on ``int / None``.
        """
        defuse_num_layers = self._compilation_params.get("defuse_num_layers")
        if defuse_num_layers is None:
            # Assumes a layer without this param was not defused -- TODO confirm.
            defuse_num_layers = 1
        # True division on purpose: a non-integral result can never equal the
        # integer product it is compared against in validate_nms_params.
        return self.input_shape[1] / defuse_num_layers

    def validate_nms_params(self):
        """Check that the NMS parameters are consistent with the layer input.

        Raises:
            UnsupportedModelError: if ``input_division_factor`` is smaller
                than 1, or if an input shape is available and
                ``input_division_factor * classes`` does not equal the value
                returned by :meth:`_classes_per_defuse`.
        """
        if self.input_division_factor < 1:
            raise UnsupportedModelError(
                f"NMS input_division_factor={self.input_division_factor} should be greater than 0",
            )
        # The height check is skipped while the layer has no input shape.
        if self.input_shape and self.input_division_factor * self.classes != self._classes_per_defuse():
            raise UnsupportedModelError(
                f"NMS has wrong input_height. input_height={self.input_shape[1]}, "
                f"input_division_factor={self.input_division_factor}, classes={self.classes}",
            )

    def to_hn(self, should_get_default_params=False):
        """Return the HN dictionary of the layer, including the NMS params.

        The dictionary produced by the base class is deep-copied before the
        NMS entries are written into its ``"params"`` section.
        """
        result = copy.deepcopy(super().to_hn(should_get_default_params))
        result["params"]["scores_threshold"] = self._scores_threshold
        result["params"]["iou_threshold"] = self._iou_threshold
        result["params"]["max_output_size"] = self._max_output_size
        result["params"]["classes"] = self._classes
        result["params"]["input_division_factor"] = self._input_division_factor

        return result

    def to_pb(self, pb_wrapper, is_multi_scope):
        """Return the protobuf node of the layer with the NMS fields filled in."""
        node = super().to_pb(pb_wrapper, is_multi_scope)
        node.type = pb_wrapper.integrated_hw_graph_base_pb2.PROTO_NETWORK_NMS
        node.scores_threshold = self._scores_threshold
        node.iou_threshold = self._iou_threshold
        node.max_output_size = self._max_output_size
        node.classes = self._classes
        node.input_division_factor = self._input_division_factor
        return node

    @property
    def scores_threshold(self):
        """The configured scores threshold."""
        return self._scores_threshold

    @property
    def iou_threshold(self):
        """The configured IoU threshold."""
        return self._iou_threshold

    @property
    def max_output_size(self):
        """The configured maximum output size."""
        return self._max_output_size

    @property
    def input_division_factor(self):
        """The configured input division factor."""
        return self._input_division_factor

    @property
    def classes(self):
        """The configured number of classes."""
        return self._classes

    @classmethod
    def from_hn(cls, hn):
        """Build the layer from an HN dictionary.

        ``input_division_factor`` is optional in ``hn["params"]`` and defaults
        to 1; the other four NMS params are required.

        Raises:
            KeyError: if a required NMS param is missing from ``hn["params"]``.
            UnsupportedModelError: if the params fail validation.
        """
        layer = super().from_hn(hn)
        params = hn["params"]
        layer.set_nms_params(
            params["scores_threshold"],
            params["iou_threshold"],
            params["max_output_size"],
            params["classes"],
            params.get("input_division_factor", 1),
        )
        return layer

    @classmethod
    def from_pb(cls, pb, pb_wrapper):
        """Build the layer from a protobuf node.

        Raises:
            UnsupportedModelError: if the params fail validation.
        """
        layer = super().from_pb(pb, pb_wrapper)
        # NOTE(review): unlike from_hn, no fallback is applied to
        # input_division_factor here; if the pb field can read as 0 when unset,
        # validation will reject the layer -- confirm this is intended.
        layer.set_nms_params(
            pb.scores_threshold,
            pb.iou_threshold,
            pb.max_output_size,
            pb.classes,
            pb.input_division_factor,
        )
        return layer

    # TODO: set proper value to undefined
    def get_equalization_handler_type(self, predecessor=None):
        """Return an unsupported, non-source classification; ``predecessor`` is ignored."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_params_sorter_handler_type(self, predecessor=None):
        """Return an unsupported, non-source classification; ``predecessor`` is ignored."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def get_dead_channels_removal_handler_type(self, predecessor=None):
        """Return an unsupported, non-source classification; ``predecessor`` is ignored."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def ibc_supported(self):
        """Return ``LayerSupportStatus.unsupported``."""
        return LayerSupportStatus.unsupported

    @property
    def finetune_supported(self):
        """Always ``False`` for this layer."""
        return False
