from hailo_model_optimization.acceleras.atomic_ops.nms_op import NMSOp
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_single_atomic import BaseHailoSingleAtomic
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EquivClassification,
    LayerHandlerType,
    LayerType,
    PrecisionMode,
)
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams

# NOTE(review): not referenced in this file; presumably a forced bit width used by
# NMS-related code elsewhere (possibly the 12-bit per-coordinate hw box encoding
# mentioned in HailoNMS._verify_and_set_hn_io_shapes) — confirm with callers.
NMS_FORCED_BITS: int = 12


# TODO: verify the hardware params this layer exports (_export_weights currently exports none).
class HailoNMS(BaseHailoSingleAtomic):
    """
    Single-op acceleras layer wrapping an ``NMSOp`` (non-maximum suppression).

    The wrapped op is always created with ``fully_native=True``. Lossy
    simulation toggles are no-ops for this layer, and it is reported as
    unsupported for equalization.

    Args:
        name: Layer name; the inner op is named ``f"{name}/nms_op"``.
        scores_threshold: Forwarded to ``NMSOp``.
        iou_threshold: Forwarded to ``NMSOp``.
        max_output_size: Forwarded to ``NMSOp``.
        classes: Forwarded to ``NMSOp``.
        input_division_factor: Forwarded to ``NMSOp``. Defaults to 1.
        logger: Optional logger, passed to both the op and the base layer.
        **kwargs: Forwarded to ``BaseHailoSingleAtomic.__init__``.

    Examples:
        >>> layer = HailoNMS(
        ...     name="nms1",
        ...     scores_threshold=0.1,
        ...     iou_threshold=0.6,
        ...     max_output_size=40,
        ...     classes=10,
        ...     input_division_factor=1,
        ... )
    """

    _hn_type = LayerType.NMS
    OP_NAME = "nms_op"

    def __init__(
        self,
        name: str,
        scores_threshold: float,
        iou_threshold: float,
        max_output_size: int,
        classes: int,
        input_division_factor: int = 1,
        logger=None,
        **kwargs,
    ):
        core_op = NMSOp(
            f"{name}/{self.OP_NAME}",
            scores_threshold=scores_threshold,
            iou_threshold=iou_threshold,
            max_output_size=max_output_size,
            classes=classes,
            input_division_factor=input_division_factor,
            # This layer is always simulated natively (see the no-op fully_native setter).
            fully_native=True,
            logger=logger,
        )
        super().__init__(name=name, core_op=core_op, logger=logger, **kwargs)

    def import_weights(self, layer_params: LayerParams):
        """Forward ``layer_params`` to the wrapped op's ``import_weights`` as keyword arguments."""
        self.atomic_op.import_weights(**dict(layer_params))

    @classmethod
    def from_hn(cls, lname, hn_element, logger=None):
        """Build the layer from its HN description.

        Args:
            lname: Name of the layer to create.
            hn_element: HN layer description. Its ``"params"`` entry must provide
                ``scores_threshold``, ``iou_threshold``, ``max_output_size`` and
                ``classes``. ``input_division_factor`` is optional and falls back
                to the constructor default (1) when missing.
            logger: Optional logger.

        Returns:
            The constructed layer after ``finalize_from_hn(hn_element)``.
        """
        params = hn_element["params"]
        layer = cls(
            name=lname,
            scores_threshold=params["scores_threshold"],
            iou_threshold=params["iou_threshold"],
            max_output_size=params["max_output_size"],
            classes=params["classes"],
            # Keep HN parsing consistent with the __init__ default instead of raising KeyError.
            input_division_factor=params.get("input_division_factor", 1),
            logger=logger,
        )
        layer.finalize_from_hn(hn_element)
        return layer

    def get_equalization_handler_type(self, predecessor_index=None):
        """NMS does not take part in equalization; it is never an equalization source."""
        return EquivClassification(LayerHandlerType.unsupported, is_source=False)

    def _export_weights(self):
        """This layer exports no weights; returns an empty dict."""
        return {}

    def _verify_and_set_hn_io_shapes(self):
        """Intentionally skip HN I/O shape verification for this layer."""
        # NOTE: acceleras floating point shape doesn't match numeric shape:
        # hw boxes output is represented by 64 bits: 16 for score and 12 for each
        # of the 4 coordinates (16 + 4 * 12 = 64), meaning the hw outputs 8 values
        # (presumably 8-bit each) that represent 5 acceleras values.
        return

    def enable_lossy(self, native_act=False, **kwargs):
        """No-op: this layer has no lossy simulation mode."""
        pass

    def disable_lossy(self, native_act=False, **kwargs):
        """No-op: this layer has no lossy simulation mode."""
        pass

    @property
    def fully_native(self):
        """Whether the wrapped op runs fully native (it is constructed with ``True``)."""
        return self.atomic_op.fully_native

    @fully_native.setter
    def fully_native(self, value):
        # Deliberately ignore assignments: the op is always built fully native.
        pass

    def _get_hn_output_shapes(self):
        # acceleras nms outputs are in a different representation from the hn,
        # thus, we want to extract the hn shapes.
        # NOTE(review): assumes self._hn_element was populated (e.g. by finalize_from_hn).
        return self._hn_element["output_shapes"]

    def get_precision_mode(self):
        """NMS always runs in its default precision mode."""
        return self.get_default_precision_mode()

    @classmethod
    def get_default_precision_mode(cls):
        """Return the fixed precision mode used by this layer (a16_w16_a16)."""
        return PrecisionMode.a16_w16_a16

    def is_differentiable(self) -> bool:
        """NMS is not differentiable."""
        return False

    def is_jit_compile_supported(self, training=False):
        """Return False: the way nms_op is implemented is not supported by jit_compile.

        The issue appears to be related to the map_fn logic inside nms_op
        (tf.image.non_max_suppression + padding).
        """
        return False
