import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.lossy_elements.quant_element import QuantElement
from hailo_model_optimization.acceleras.utils.acceleras_definitions import DEFAULT_BOX_AND_OBJ_PXLS
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasValueError
from hailo_model_optimization.acceleras.utils.opt_utils import _get_limvals_numeric, limvals_to_zp_scale

BYTE_W = 8  # bits per byte; used to recombine 16-bit inputs split into LSB/MSB byte rows
NUM_COORD_PER_PROPOSAL = 4  # y_min, x_min, y_max, x_max
VALUES_PER_PROPOSAL = NUM_COORD_PER_PROPOSAL + 1  # coordinates + 1 score


class NMSOp(BaseAtomicOp):
    """
    Emulates the NMS function in Hailo.

    The op assumes it gets its input from a proposal generation layer, where the axes are ordered
    [batch, class, proposals_per_output, num_boxes * 5]. Each box is 5 values: 4 box coordinates and one score.
    The op filters the boxes according to the score threshold (on the HW this is done in the reshaper) and then
    runs the NMS operation.

    Args:
        name : Name of the op
        scores_threshold : Threshold to drop boxes according to the score [0:1]
        iou_threshold : IOU threshold to remove overlapping boxes in the NMS [0:1]
        max_output_size : Maximum number of output proposals. The layer pads with zeros if the actual number
            of proposals is less than max_output_size
        classes : Number of classes. The layer processes each class separately
        input_division_factor: Division of image rows to classes
        fully_native : Run in fully native mode

    Examples:
        >>> op = NMSOp(
        ...     name="nms",
        ...     scores_threshold=0.1,
        ...     iou_threshold=0.6,
        ...     max_output_size=40,
        ...     classes=10,
        ...     input_division_factor=1,
        ... )
    """

    # Bit widths of the quantization elements used for the boxes and the scores
    # (see create_input_encoding_candidates).
    BOXES_BITS = 15
    SCORES_BITS = 15
    # Set by import_weights() and _build() respectively.
    _proposals_per_output: int
    _num_proposals: int
    num_inputs = 1
    num_outputs = 1  # might be 2?

    def __init__(
        self,
        name: str,
        scores_threshold: float,
        iou_threshold: float,
        max_output_size: int,
        classes: int,
        input_division_factor: int = 1,
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._scores_threshold = scores_threshold
        self._iou_threshold = iou_threshold
        self._max_output_size = max_output_size
        self._classes = classes
        self._input_division_factor = input_division_factor
        self.fully_native = fully_native

    def import_weights(self, proposals_per_output: int = 4, **kwargs):
        """Store the number of proposals per output row. Any other keyword arguments are ignored."""
        self._proposals_per_output = proposals_per_output

    def _build(self, input_shape):
        """
        Validate the NMS input shape and derive the number of proposals per class.

        Raises:
            ValueError: if dim 1 is not divisible by ``input_division_factor``, if dim 3 is not a multiple of
                ``VALUES_PER_PROPOSAL``, or if dim 1 divided by ``input_division_factor`` differs from ``classes``.
        """
        if input_shape[1] % self._input_division_factor != 0:
            raise ValueError(
                f"Dim 1 {input_shape[1]} must be divided by input_division_factor "
                f"{self._input_division_factor} at the NMS input",
            )
        if input_shape[3] % VALUES_PER_PROPOSAL != 0:
            raise ValueError(f"Dim 3 {input_shape[3]} of the nms input must be divided by {VALUES_PER_PROPOSAL}")
        if int(input_shape[1] // self._input_division_factor) != self._classes:
            raise ValueError(
                f"nms first dim {input_shape[1]} divided by division factor "
                f"{self._input_division_factor} must be {self._classes}",
            )
        # Relies on import_weights() having set _proposals_per_output beforehand.
        self._num_proposals = self._proposals_per_output * input_shape[3] // VALUES_PER_PROPOSAL

    def call_hw_sim(self, inputs, training=False, **kwargs):
        """
        Run the NMS emulation on ``inputs[0]``.

        The upper part of this method is copied from tf_model.
        Currently this section works only when _input_division_factor=1; it needs to be modified to work like the
        tf_model upper section.
        TODO - https://hailotech.atlassian.net/browse/SDK-29748

        Returns:
            The NMS result reshaped to
            [batch, classes, DEFAULT_BOX_AND_OBJ_PXLS, input_division_factor * max_output_size].
        """
        inp = inputs[0]
        # A 16-bit input arrives with twice the expected rows on dim 2: even rows hold the LSB byte and odd rows
        # hold the MSB byte, so recombine each pair into a single value.
        if inp.shape[2] == self._proposals_per_output * 2:
            lsb = inp[:, :, 0::2, :]
            msb = inp[:, :, 1::2, :]
            inputs_mod = msb * (2**BYTE_W) + lsb
        else:
            inputs_mod = inp
        nms_out = []
        for div_ind in range(self._input_division_factor):
            for class_ind in range(self._classes):
                # Dim 1 is laid out as input_division_factor consecutive groups of `classes` rows.
                tot_ind = div_ind * self._classes + class_ind
                input_one_class = inputs_mod[:, tot_ind : tot_ind + 1, :, :]
                # Flatten to [batch, num_proposals, VALUES_PER_PROPOSAL].
                inputs_reshaped = tf.reshape(input_one_class, [-1, self._num_proposals, VALUES_PER_PROPOSAL])
                boxes = inputs_reshaped[:, :, 0:NUM_COORD_PER_PROPOSAL]
                scores = inputs_reshaped[:, :, -1]
                nms_out_one_class = self._core_nms_logic(boxes, scores, inputs_reshaped)
                nms_out_one_class = tf.ensure_shape(
                    nms_out_one_class,
                    [inp.shape[0], self._max_output_size, VALUES_PER_PROPOSAL],
                )

                if div_ind == 0:
                    nms_out.append(nms_out_one_class)
                else:
                    # Later division slices are appended along the proposal axis of the same class.
                    nms_out[class_ind] = tf.concat([nms_out[class_ind], nms_out_one_class], axis=1)
        # [batch, div_factor*max_output_size, 5, class] -> [batch, class, 5, div_factor*max_output_size]
        nms_out = tf.transpose(a=tf.stack(nms_out, axis=3), perm=[0, 3, 2, 1])
        nms_out = tf.reshape(
            nms_out,
            [-1, self._classes, DEFAULT_BOX_AND_OBJ_PXLS, self._input_division_factor * self._max_output_size],
        )
        return nms_out

    @tf.function
    def _core_nms_logic(self, boxes, scores, inputs_reshaped):
        """
        Apply score filtering and NMS to every image of the batch for a single class.

        For each image: keep the proposals whose score is strictly greater than ``scores_threshold``, run
        ``tf.image.non_max_suppression`` with ``max_output_size`` / ``iou_threshold``, gather the full selected
        rows (coordinates + score) from ``inputs_reshaped`` and zero-pad them to ``max_output_size`` rows.
        The result is wrapped in ``tf.stop_gradient``.
        """

        # We use tf.function to make sure the map fn works in parallel during the inference.
        def _single_image_nms_fn(args):
            per_image_boxes = args[0]
            per_image_scores = args[1]
            per_image_orig_input = args[2]
            # Score threshold filtering (done by the reshaper on the HW).
            th_ind = tf.squeeze(tf.where(per_image_scores > self._scores_threshold), axis=1)
            per_image_boxes_th = tf.gather(per_image_boxes, th_ind, axis=0)
            per_image_scores_th = tf.gather(per_image_scores, th_ind, axis=0)
            per_image_orig_input_th = tf.gather(per_image_orig_input, th_ind, axis=0)
            op_ind = tf.image.non_max_suppression(
                per_image_boxes_th,
                per_image_scores_th,
                self._max_output_size,
                self._iou_threshold,
            )
            op_ind_shape = tf.shape(input=op_ind)
            op_one_im = tf.gather(per_image_orig_input_th, op_ind, axis=0)
            # Pad with zero rows so every image yields exactly max_output_size proposals.
            op_one_im_pad = tf.pad(
                tensor=op_one_im,
                paddings=([0, self._max_output_size - op_ind_shape[0]], [0, 0]),
                mode="CONSTANT",
                constant_values=0,
            )
            return tf.ensure_shape(op_one_im_pad, [self._max_output_size, VALUES_PER_PROPOSAL])

        nms_out_one_class = tf.map_fn(
            _single_image_nms_fn,
            (boxes, scores, inputs_reshaped),
            fn_output_signature=tf.float32,
            parallel_iterations=32,
        )
        nms_out_one_class_stop_grad = tf.nest.map_structure(tf.stop_gradient, nms_out_one_class)
        return nms_out_one_class_stop_grad

    def create_weight_quant_element(self, **kwargs):
        """No-op: the NMS has no weights to quantize."""
        pass

    def is_differentiable(self) -> bool:
        """The NMS output is wrapped in stop_gradient, so the op is reported as non-differentiable."""
        return False

    def call_native(self, inputs, **kwargs):
        """The native path is the same as the HW-simulation path."""
        return self.call_hw_sim(inputs, **kwargs)

    def create_hw_params(self, *args, **kwargs):
        """No-op: the only HW parameter (input_division_factor) is exported directly by export_hw_params."""
        pass

    def create_input_encoding_candidates(self, input_index, input_lossy_external=None, translation_config=None):
        """
        Compute the input scales and zero point of the NMS.

        The NMS has 2 different scales, one for the boxes and one for the scores. Currently the two are matched by
        taking their maximum, and that single scale is used for every input channel.

        Raises:
            AccelerasValueError: if the boxes or scores zero point is not zero. The check runs before any encoding
                is written, so a failure leaves ``input_scales`` / ``input_zero_points`` untouched.
        """
        input_stats = self.get_input_stats(input_index)

        # Stats are laid out per proposal as [coordinates x NUM_COORD_PER_PROPOSAL, score]; build a boolean mask
        # selecting the score entries.
        proposals = int(input_stats.min.shape[0]) // VALUES_PER_PROPOSAL
        scores_idx = tf.repeat([False, True], repeats=[NUM_COORD_PER_PROPOSAL, 1])
        scores_idx = tf.tile(scores_idx, [proposals])

        scores_min = input_stats.min[scores_idx]
        scores_max = input_stats.max[scores_idx]

        scores_limval = _get_limvals_numeric(scores_min, scores_max, 1)
        scores_lossy_element = QuantElement(signed=False, bits=NMSOp.SCORES_BITS)
        scores_zp, scores_scales, _ = limvals_to_zp_scale(
            scores_limval, scores_lossy_element, self.full_name, self._logger
        )
        boxes_idx = tf.math.logical_not(scores_idx)
        boxes_min = input_stats.min[boxes_idx]
        boxes_max = input_stats.max[boxes_idx]
        boxes_limval = _get_limvals_numeric(boxes_min, boxes_max, 1)
        boxes_lossy_element = QuantElement(signed=False, bits=NMSOp.BOXES_BITS)
        boxes_zp, boxes_scales, _ = limvals_to_zp_scale(boxes_limval, boxes_lossy_element, self.full_name, self._logger)

        # Validate before mutating any state so a failure does not leave a half-written encoding behind.
        if (boxes_zp != 0) or (scores_zp != 0):
            raise AccelerasValueError(
                f"boxes_zp = {boxes_zp}, scores_zp = {scores_zp}. Proposal generator zp must be zero.",
            )

        # In the future, we could have different scales for the scores and bboxes, e.g.
        # tf.repeat([boxes_scales, scores_scales], repeats=[NUM_COORD_PER_PROPOSAL, 1]).
        # Currently we just match the scales.
        nms_scale = tf.math.maximum(boxes_scales, scores_scales)
        input_scales = tf.repeat([nms_scale, nms_scale], repeats=[NUM_COORD_PER_PROPOSAL, 1])

        input_scales = tf.tile(input_scales, [proposals])
        self.input_scales[0] = input_scales
        self.input_zero_points[0] = 0

    def create_output_encoding_candidates(
        self,
        output_index,
        force_range=None,
        output_lossy_external=None,
        translation_config=None,
        split_precision_zp=None,
    ):
        """Output encoding is derived entirely from the input encoding (see _match_in_and_out_qp)."""
        self._match_in_and_out_qp()

    def enforce_encoding(self, *args, **kwargs):
        """Re-derive the output encoding from the current input encoding."""
        self._match_in_and_out_qp()

    def export_quant_weights(self) -> dict:
        """Export the quantization-relevant configuration of the op."""
        return {
            "input_division_factor": self._input_division_factor,
        }

    def export_hw_params(self):
        """Export the HW parameters of the op; input_division_factor is emitted as a uint8 scalar array."""
        return {
            "input_division_factor": np.array(self._input_division_factor, np.uint8),
        }

    def _match_in_and_out_qp(self):
        """
        Derive the output scales and zero point from the input scales.

        The output scales are affected only by the input scales. Even though the output statistics will differ
        from the input statistics (since the NMS drops some boxes), the scales have to match because there is no
        element inside the NMS that can change the scales.
        This function will be relevant only after HailoRT supports 2 output scales for the NMS.
        """
        boxes_scale = self.input_scales[0][0]
        # The score is the last value of each proposal, right after the coordinates.
        score_scale = self.input_scales[0][NUM_COORD_PER_PROPOSAL]

        output_channels = self.output_shape[-1]
        proposals = int(output_channels) // VALUES_PER_PROPOSAL

        # Repeat the per-proposal pattern [box x NUM_COORD_PER_PROPOSAL, score] over all full proposals; any
        # leftover channels (when output_channels is not a multiple of VALUES_PER_PROPOSAL) get the score scale.
        output_scales = tf.repeat([boxes_scale, score_scale], repeats=[NUM_COORD_PER_PROPOSAL, 1])
        output_scales = tf.tile(output_scales, [proposals])
        output_scales = tf.concat(
            [output_scales, tf.tile([score_scale], [output_channels - output_scales.shape[0]])],
            axis=0,
        )
        self.output_scale = output_scales

        self.output_zero_point = np.float32(0)
