from typing import Sequence

import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_non_arithmetic_op import BaseNonArithmeticAtomicOp


class ReshapeOp(BaseNonArithmeticAtomicOp):
    """
    Reshape a tensor without changing its values or their order.

    Given an input tensor, this operation returns a new ``tf.Tensor`` that holds the same values in the same
    order, but with the shape given by ``reshape_size``. Like ``tf.reshape``, it changes neither the order nor
    the total number of elements in the tensor.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(
        self,
        name: str,
        reshape_size: Sequence[int],
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        """
        Args:
            name: Name of the op, forwarded to the base class.
            reshape_size: Target shape passed to ``tf.reshape``. It must be indexable: ``reshape_size[1:]``
                gives the non-leading output dimensions and ``reshape_size[-1]`` the output channel count.
                (Annotated as a sequence; the previous ``int`` annotation could not support that indexing.)
            logger: Forwarded to the base class.
            fully_native: Forwarded to the base class.
            **kwargs: Forwarded to the base class.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self.reshape_size = reshape_size

    def call_native(self, inputs, **kwargs):
        """Reshape ``inputs[0]`` to ``self.reshape_size`` using ``tf.reshape``.

        NOTE(review): ``reshape_size`` is used verbatim here, including its first entry, whereas
        ``_compute_output_shape`` takes the leading dimension from the input shape. Presumably the first
        entry is ``-1`` or equals the batch size; TODO confirm with callers.
        """
        return tf.reshape(inputs[0], self.reshape_size)

    def _compute_output_shape(self, input_shape):
        """Return the input's leading (batch) dimension followed by ``reshape_size[1:]``."""
        batch_dim = input_shape[0]
        return [batch_dim, *self.reshape_size[1:]]

    def enforce_encoding(self, *args, **kwargs):
        """Derive the output quantization encoding from the input encoding.

        The zero point is copied through unchanged. The input scales are copied as-is when the output
        channel count (``reshape_size[-1]``) equals the number of input scales. Otherwise, all input scales
        must be identical, and that single scale is repeated once per output channel.

        Raises:
            RuntimeError: If the channel count changes and the input scales are not all identical.
        """
        out_channels = self.reshape_size[-1]
        # NOTE(review): the zero point is copied as-is even when the channel count changes; if it can be
        # per-channel, its length may not match out_channels. Confirm.
        self.output_zero_point = self.input_zero_point
        if out_channels == len(self.input_scale):
            self.output_scale = self.input_scale
            return
        # TODO: technically we can support flat to frames, by making sure the scales are grouped
        unique_scales, _ = tf.unique(self.input_scale)
        if len(unique_scales) != 1:
            raise RuntimeError("Cannot enforce encoding for reshape layer with multiple input scales")
        # A single shared scale can be broadcast to any number of output channels.
        self.output_scale = tf.repeat(unique_scales[0], out_channels)
