#!/usr/bin/env python
import tensorflow as tf


def tf_meshgrid(x, y):
    """
    Build a pair of coordinate grids by tiling each tensor with the other's shape.

    ``x`` is repeated according to ``tf.shape(y)`` and ``y`` according to
    ``tf.shape(x)``, so both inputs must have the same rank. With a row
    ``x`` of shape (1, N) and a column ``y`` of shape (M, 1) - which is how
    ``create_anchors`` calls this helper - both outputs have shape (M, N).

    Args:
        x (Tensor): values spread along the grid columns, shape = (1, N)
        y (Tensor): values spread along the grid rows, shape = (M, 1)

    Returns:
        tuple: (xgrid, ygrid), the tiled ``x`` and the tiled ``y``; for the
        shapes above both are (M x N)

    """
    with tf.compat.v1.name_scope("Meshgrid"):
        # The shape of the *other* tensor is used as the tile multiples.
        x_repeats = tf.shape(input=y)
        tiled_x = tf.tile(x, x_repeats)
        y_repeats = tf.shape(input=x)
        tiled_y = tf.tile(y, y_repeats)
    return tiled_x, tiled_y


def create_anchors(
    num_of_anchors,
    y_centers,
    x_centers,
    anchors_heights,
    anchors_heights_div_2,
    anchors_heights_minus_div_2,
    anchors_widths,
    anchors_widths_div_2,
    anchors_widths_minus_div_2,
):
    """
    Build the dense anchor tensor of a single branch.

    For every anchor ``k`` two columns are read from each centers table:
    column ``2 * k`` (used for the ymin / xmin computation) and column
    ``2 * k + 1`` (used for the ymax / xmax computation). Each pair of
    columns is expanded to a full (H x W) grid and stacked with the anchor's
    height and width constants.

    Args:
        num_of_anchors (int): num of anchors for the given branch - a scalar
        y_centers (tf.Variable): y centers used for the ymin and ymax
            computation, shape=(H x (2 * num_of_anchors)) - i.e. (19 x 6)
        x_centers (tf.Variable): x centers used for the xmin and xmax
            computation, shape=(W x (2 * num_of_anchors)) - i.e. (19 x 6)
        anchors_heights (tf.Variable): anchors' heights, shape (num_of_anchors,)
        anchors_heights_div_2 (tf.Variable): anchors' heights divided by two, shape (num_of_anchors,)
        anchors_heights_minus_div_2 (tf.Variable): minus the anchors' heights divided by two,
            shape (num_of_anchors,)
        anchors_widths (tf.Variable): anchors' widths, shape (num_of_anchors,)
        anchors_widths_div_2 (tf.Variable): anchors' widths divided by two, shape (num_of_anchors,)
        anchors_widths_minus_div_2 (tf.Variable): minus the anchors' widths divided by two,
            shape (num_of_anchors,)

    Returns:
        Tensor: anchors laid out on the branch spatial grid. Every anchor
            contributes the following 10 channels, in this order:
            [y_center_1, x_center_1, y_center_2, x_center_2,
             height, height/2, -height/2,
             width, width/2, -width/2]
        shape = (HxWx(num_of_anchors * 10))
        i.e. for a branch with 3 anchors on a 19x19 grid: (19x19x30)

    """

    def _center_grids(column):
        # A single x_centers column becomes a (1, W) row while the matching
        # y_centers column stays (H, 1); tf_meshgrid expands both to (H, W).
        x_row = tf.transpose(a=tf.slice(x_centers, [0, column], [-1, 1]), perm=[1, 0])
        y_col = tf.slice(y_centers, [0, column], [-1, 1])
        return tf_meshgrid(x_row, y_col)

    # (full, half, minus-half) constants: heights first, then widths, which is
    # also the channel order of the output.
    size_tables = (
        (anchors_heights, anchors_heights_div_2, anchors_heights_minus_div_2),
        (anchors_widths, anchors_widths_div_2, anchors_widths_minus_div_2),
    )

    with tf.compat.v1.name_scope("Branch_Anchors_Generator"):
        per_anchor_grids = []
        for anchor_idx in range(num_of_anchors):
            min_column = 2 * anchor_idx
            # centers feeding the xmin / ymin computation
            x_grid_min, y_grid_min = _center_grids(min_column)
            # centers feeding the xmax / ymax computation
            x_grid_max, y_grid_max = _center_grids(min_column + 1)

            # Tile multiples [H, W, 1]: replicate a (1, 1, 3) constant over the grid.
            tile_multiples = tf.concat([tf.shape(input=x_grid_min), [1]], 0)

            size_grids = []
            for full, half, minus_half in size_tables:
                triplet = tf.concat(
                    [[full[anchor_idx]], [half[anchor_idx]], [minus_half[anchor_idx]]],
                    axis=0,
                )
                size_grids.append(tf.tile(tf.reshape(triplet, shape=[1, 1, 3]), tile_multiples))

            center_channels = [
                tf.expand_dims(grid, axis=2) for grid in (y_grid_min, x_grid_min, y_grid_max, x_grid_max)
            ]
            per_anchor_grids.append(tf.concat(center_channels + size_grids, axis=2))

        # Anchors are laid side by side along the channel axis.
        return tf.concat(per_anchor_grids, axis=2)


def decode_branch(
    detections_centers,
    detections_w_h,
    num_of_anchors,
    y_centers,
    x_centers,
    anchors_heights,
    anchors_heights_div_2,
    anchors_heights_minus_div_2,
    anchors_widths,
    anchors_widths_div_2,
    anchors_widths_minus_div_2,
):
    """
    Decode one branch of box-regression outputs into [ymin, xmin, ymax, xmax] boxes.

    The anchors of the branch are generated with ``create_anchors``, repeated
    once per batch element and flattened to one row per (batch, y, x, anchor)
    proposal. The predictions are flattened the same way, decoded with
    ``decode_bboxes`` and folded back onto the spatial grid.

    Args:
        detections_centers (Tensor): centers scales from box regression head,
            shape (batch_size, H, W, 2 * num_of_anchors)
        detections_w_h (Tensor): width and height scales from box regression head,
            shape (batch_size, H, W, 2 * num_of_anchors)
        num_of_anchors (int): num of anchors for the given branch - a scalar
        y_centers (tf.Variable): y centers used for the ymin and ymax
            computation, shape=(H x (2 * num_of_anchors)) - i.e. (19 x 6)
        x_centers (tf.Variable): x centers used for the xmin and xmax
            computation, shape=(W x (2 * num_of_anchors)) - i.e. (19 x 6)
        anchors_heights (tf.Variable): anchors' heights, shape (num_of_anchors,)
        anchors_heights_div_2 (tf.Variable): anchors' heights divided by two, shape (num_of_anchors,)
        anchors_heights_minus_div_2 (tf.Variable): minus the anchors' heights divided by two,
            shape (num_of_anchors,)
        anchors_widths (tf.Variable): anchors' widths, shape (num_of_anchors,)
        anchors_widths_div_2 (tf.Variable): anchors' widths divided by two, shape (num_of_anchors,)
        anchors_widths_minus_div_2 (tf.Variable): minus the anchors' widths divided by two,
            shape (num_of_anchors,)

    Returns:
        Tensor: decoded boxes, shape (batch_size, H, W, num_of_anchors * 4);
        each consecutive group of 4 channels is [ymin, xmin, ymax, xmax] of one anchor

    """
    with tf.compat.v1.name_scope("Decode_branch"):
        anchors_grid = create_anchors(
            num_of_anchors=num_of_anchors,
            y_centers=y_centers,
            x_centers=x_centers,
            anchors_heights=anchors_heights,
            anchors_heights_div_2=anchors_heights_div_2,
            anchors_heights_minus_div_2=anchors_heights_minus_div_2,
            anchors_widths=anchors_widths,
            anchors_widths_div_2=anchors_widths_div_2,
            anchors_widths_minus_div_2=anchors_widths_minus_div_2,
        )
        # One row of 10 values per (y, x, anchor).
        flat_anchors = tf.reshape(anchors_grid, [-1, 10])

        # The anchors do not depend on the batch element, so they are repeated
        # batch_size times to line up row-by-row with the flattened predictions.
        batch_size = tf.shape(input=detections_centers)[0]
        batched_anchors = tf.tile(tf.expand_dims(flat_anchors, 0), [batch_size, 1, 1])
        batched_anchors = tf.reshape(batched_anchors, [-1, 10])

        # Flatten the predictions to 2D: two center values and two size values per row.
        flat_centers = tf.reshape(detections_centers, [-1, 2])
        flat_w_h = tf.reshape(detections_w_h, [-1, 2])
        flat_predictions = tf.concat([flat_centers, flat_w_h], axis=1)

        decoded = decode_bboxes(batched_anchors, flat_predictions)

        # NOTE(review): H and W come from the static shape of detections_centers,
        # so they are presumably expected to be known here - confirm with callers.
        grid_height = detections_centers.shape[1]
        grid_width = detections_centers.shape[2]
        output_shape = [batch_size, grid_height, grid_width, num_of_anchors * 4]
        return tf.reshape(decoded, output_shape)


def decode_bboxes(anchors, detections):
    """
    Decode flattened regression outputs against their anchors.

    Both inputs are unstacked along axis 1, so each row describes a single
    proposal. The box corners are computed as:

        ymin = y_center_1 + ty * height + th * (-height / 2)
        xmin = x_center_1 + tx * width  + tw * (-width / 2)
        ymax = y_center_2 + ty * height + th * (height / 2)
        xmax = x_center_2 + tx * width  + tw * (width / 2)

    Args:
        anchors (Tensor): one row per proposal with the 10 anchor values
            [y_center_1, x_center_1, y_center_2, x_center_2,
             height, height/2, -height/2, width, width/2, -width/2],
            shape (num of proposals, 10)
        detections (Tensor): one row per proposal with the regression values
            [ty, tx, th, tw], shape (num of proposals, 4)

    Return:
        Tensor: decoded boxes [ymin, xmin, ymax, xmax], shape (num of proposals, 4)

    """
    with tf.compat.v1.name_scope("Decode_bboxes"):
        (
            y_anchor_min,
            x_anchor_min,
            y_anchor_max,
            x_anchor_max,
            anchor_h,
            anchor_h_half,
            anchor_h_neg_half,
            anchor_w,
            anchor_w_half,
            anchor_w_neg_half,
        ) = tf.unstack(anchors, axis=1)
        shift_y, shift_x, scale_h, scale_w = tf.unstack(detections, axis=1)

        # NOTE(review): the height/width terms are used as provided; no
        # exponential is applied in this function - presumably it is applied
        # before the values reach this point, confirm against the model.
        top = y_anchor_min + shift_y * anchor_h + scale_h * anchor_h_neg_half
        left = x_anchor_min + shift_x * anchor_w + scale_w * anchor_w_neg_half
        bottom = y_anchor_max + shift_y * anchor_h + scale_h * anchor_h_half
        right = x_anchor_max + shift_x * anchor_w + scale_w * anchor_w_half

        # Stacking puts the 4 coordinates on axis 0; the transpose puts
        # the proposals back on axis 0.
        corners = tf.stack([top, left, bottom, right])
        return tf.transpose(a=corners)
