import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_non_arithmetic_op import BaseNonArithmeticAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasValueError

# NOTE(review): presumably the index of the width axis (axis 2 of a
# [batch, height, width, features] tensor), i.e. the axis along which the
# proposals are zero-padded. Nothing in the visible code of this file
# references this constant -- confirm external usage before changing/removing.
PAD_AXIS = 2


class ProposalGeneratoOp(BaseNonArithmeticAtomicOp):
    """
    Emulates the proposal generator operation, which rearranges the output
    of the bbox decoder before the nms operation.

    The op has two 4D inputs laid out as [Batch, Height, Width, Features]:
    input 0 holds the boxes (4 coordinates per anchor) and input 1 holds the
    scores (one score per class per anchor). Both must share height and width.

    The order of the data after this layer is
    [Batch, Classes * input_division_factor, proposals_per_output, prop_gen_outputs * 5].
    Each proposal is the 4 coordinates of the box followed by one score. At the
    input to this layer, each anchor has scores per class and one set of box
    coordinates; the layer broadcasts the box coordinates to all the classes.
    The width is zero-padded up to the next multiple of ``proposals_per_output``.

    Args:
        input_division_factor : divides the image rows into groups that are
        emitted as additional classes. This should be used in case the image
        size is too big for the nms buffer size. For example, if the
        input_division_factor is 2 and the number of classes is 10, the layer
        will create an effective 20 classes where the height of the image is
        half. Note that this operation is lossy even in native mode (rows that
        do not fit into ``height // input_division_factor`` groups are dropped).

    Examples:
        Examples of use
        >>> op = ProposalGeneratoOp("proposal_generator", input_division_factor=2)

    """

    num_inputs = 2
    num_outputs = 1

    # Set in __init__ / import_weights.
    _input_division_factor: int
    _number_of_coordinates_per_proposal: int
    _proposals_per_output: int
    # Derived from the input shapes in _build.
    _height: int
    _width: int
    _width_pad: int
    _zeros_width: int
    _num_anchors: int
    _classes: int
    _group_size: int
    _total_proposals: int
    _prop_gen_outputs: int
    _values_per_proposal: int

    def __init__(
        self,
        name: str,
        input_division_factor: int = 1,
        logger=None,
        fully_native=None,
        proposals_per_output=4,
        **kwargs,
    ):
        """
        Args:
            name: op name, forwarded to the base op.
            input_division_factor: number of row groups the image is split into.
            logger: forwarded to the base op.
            fully_native: forwarded to the base op.
            proposals_per_output: number of proposals packed along axis 2 of the
                output; the width is padded to a multiple of this value.
            **kwargs: forwarded to the base op.

        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._input_division_factor = input_division_factor
        self._number_of_coordinates_per_proposal = 4
        self._proposals_per_output = int(proposals_per_output)

    def import_weights(self, proposals_per_output=4, **kwargs):
        """Set ``proposals_per_output``; any other keyword argument is ignored."""
        self._proposals_per_output = int(proposals_per_output)

    def _build(self, input_shape):
        """
        Derive the static layout parameters from the input shapes.

        Args:
            input_shape: pair of shapes, ``input_shape[0]`` for the boxes and
                ``input_shape[1]`` for the scores, each indexed as
                [batch, height, width, features].

        Raises:
            ValueError: if boxes and scores differ in height or width, if the
                boxes features are not a positive multiple of 4, or if the
                scores features are not a multiple of the number of anchors.

        """
        boxes_shape = input_shape[0]
        scores_shape = input_shape[1]
        if (boxes_shape[1] != scores_shape[1]) or (boxes_shape[2] != scores_shape[2]):
            raise ValueError(
                f"proposal generator hieght and width must bethe same between scores and boxes in {self.full_name}",
            )
        self._height = boxes_shape[1]
        self._width = boxes_shape[2]
        # Round the width up to the next multiple of proposals_per_output.
        self._width_pad = int(np.ceil(self._width / self._proposals_per_output) * self._proposals_per_output)
        self._zeros_width = self._width_pad - self._width

        coords = self._number_of_coordinates_per_proposal
        boxes_features = boxes_shape[3]
        # Without this check a remainder would be silently floored away, and
        # fewer than `coords` features would yield zero anchors and a
        # ZeroDivisionError in the modulo below.
        if boxes_features < coords or boxes_features % coords != 0:
            raise ValueError(
                f"boxes features ({boxes_features}) must be a positive multiple of the number of "
                f"coordinates per proposal ({coords}) in op {self.full_name}",
            )
        self._num_anchors = int(boxes_features // coords)
        if scores_shape[3] % self._num_anchors != 0:
            raise ValueError(
                f"scores features ({input_shape[1][3]}) must be divided by the num of "
                f"anchors ({self._num_anchors}) in op {self.full_name}",
            )
        self._classes = int(scores_shape[3] // self._num_anchors)
        # TODO: we might lose the final rows
        self._group_size = int(self._height // self._input_division_factor)
        self._total_proposals = self._num_anchors * self._group_size * self._width_pad  # total proposals per group
        self._prop_gen_outputs = int(self._total_proposals // self._proposals_per_output)
        self._values_per_proposal = self._number_of_coordinates_per_proposal + 1

    def call_native(self, inputs, **kwargs):
        """
        Rearrange boxes (``inputs[0]``) and scores (``inputs[1]``) into proposals.

        The rows are split into ``input_division_factor`` groups of
        ``_group_size`` rows; each group is converted independently and the
        results are concatenated along the classes axis (axis 1).

        Returns:
            Tensor of shape
            [Batch, Classes * input_division_factor, proposals_per_output,
            prop_gen_outputs * 5].

        """
        group_outputs = []
        for div_ind in range(self._input_division_factor):
            # slice the rows of the group
            row_start = div_ind * self._group_size
            row_stop = row_start + self._group_size
            boxes = inputs[0][:, row_start:row_stop, :, :]
            scores = inputs[1][:, row_start:row_stop, :, :]
            group_outputs.append(self._group_proposals(boxes, scores))

        return tf.concat(group_outputs, 1, name="proposal_out")

    def _group_proposals(self, boxes, scores):
        """Convert the boxes and scores of one row group into the proposals layout."""
        coords = self._number_of_coordinates_per_proposal
        per_output = self._proposals_per_output

        # add padding
        boxes_padded = self._pad_proposals(boxes)
        scores_padded = self._pad_proposals(scores)

        # concat the scores with the boxes. We duplicate the boxes by the number of classes
        boxes_reshaped = tf.reshape(boxes_padded, [-1, 1, self._total_proposals, coords])
        boxes_reshaped = tf.tile(boxes_reshaped, [1, self._classes, 1, 1])
        scores_reshaped = tf.reshape(
            scores_padded,
            [-1, self._group_size, self._width_pad, self._num_anchors, self._classes],
        )
        # Move the classes next to the batch so that scores are ordered like the boxes.
        scores_reshaped = tf.transpose(a=scores_reshaped, perm=[0, 4, 1, 2, 3])
        scores_reshaped = tf.reshape(scores_reshaped, [-1, self._classes, self._total_proposals, 1])
        proposals = tf.concat([boxes_reshaped, scores_reshaped], 3)

        # Split the padded width into chunks of `per_output` columns. The
        # leading -1 absorbs batch * classes.
        proposals = tf.reshape(
            proposals,
            [
                -1,
                self._group_size,
                self._width_pad // per_output,
                per_output,
                self._num_anchors,
                self._values_per_proposal,
            ],
        )
        # -> [batch * classes, width chunks, rows, anchors, per_output, values]
        proposals = tf.transpose(a=proposals, perm=[0, 2, 1, 4, 3, 5])

        # The axis that follows prop_gen_outputs holds `per_output` proposals.
        # Using the number of box coordinates here (as was done before) is only
        # correct when proposals_per_output == 4; otherwise the inferred batch
        # dimension is wrong.
        proposals_out = tf.reshape(
            proposals,
            [-1, self._classes, self._prop_gen_outputs, per_output, self._values_per_proposal],
        )
        proposals_out = tf.transpose(a=proposals_out, perm=[0, 1, 3, 2, 4])
        return tf.reshape(
            proposals_out,
            [-1, self._classes, per_output, self._values_per_proposal * self._prop_gen_outputs],
        )

    def _pad_proposals(self, tensor):
        """Zero-pad the end of axis 2 (width) by ``_zeros_width`` columns, if any."""
        if self._zeros_width == 0:
            return tensor
        padding_tensor = tf.constant([[0, 0], [0, 0], [0, self._zeros_width], [0, 0]])
        return tf.pad(tensor, padding_tensor)

    def enforce_encoding(self, *args, **kwargs):
        """
        Derive the output scale and zero point from the input encodings.

        When the boxes scale holds more than one element (vector scales), the
        output scale is the pattern [box, box, box, box, score] repeated
        ``_prop_gen_outputs`` times, built from the first element of each input
        scale, and the output zero point is set to 0. When the boxes scale has
        a single element nothing is changed.

        Raises:
            AccelerasValueError: if either input zero point is not zero.

        """
        # Set the scales of the scores as score_scale and boxes as box_scale
        if tf.reshape(self.input_scales[0], [-1, 1]).shape[0] == 1:
            # Native test (scalar scale) - nothing to enforce
            return

        # vector scales
        boxes_scales = self.input_scales[0][0]
        scores_scales = self.input_scales[1][0]

        proposal_scale = tf.repeat(
            [boxes_scales, scores_scales],
            repeats=[self._number_of_coordinates_per_proposal, 1],
            axis=0,
        )
        self.output_scale = tf.tile(proposal_scale, [self._prop_gen_outputs])

        boxes_zp = self.input_zero_points[0]
        score_zp = self.input_zero_points[1]
        if (boxes_zp != 0) or (score_zp != 0):
            raise AccelerasValueError(
                f"boxes_zp = {boxes_zp}, score_zp = {score_zp}. Proposal generator zp must be zero.",
            )
        self.output_zero_point = tf.convert_to_tensor(0)
