import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp


class SpatialTransposeOp(BaseAtomicOp):
    """
    Swaps the two spatial axes (axes 1 and 2) of a rank-4 tensor.

    For a rank-4 input the op applies ``tf.transpose(x, perm=[0, 2, 1, 3])``,
    i.e. ``out[b, j, i, c] == in[b, i, j, c]``. Inputs of any other rank are
    returned unchanged.

    NOTE(review): the axis names assume a (batch, height, width, channels)
    layout, so this would swap height and width; confirm against the callers.

    Notes:
      1. The op has no weights: ``create_weight_quant_element`` and
         ``create_hw_params`` are no-ops and ``export_weights`` returns an
         empty dict.
      2. Hardware simulation runs the same code path as the native call. A
         pure axis permutation does not change values, so emulation is
         bit exact.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=None, **kwargs):
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

    def create_weight_quant_element(self, **kwargs):
        """No-op: this op has no weights to quantize."""
        pass

    def call_hw_sim(self, inputs, **kwargs):
        """Run the HW-simulation path. It is identical to ``call_native``
        because a transpose only moves values and never changes them."""
        return self.call_native(inputs, **kwargs)

    def call_native(self, inputs, **kwargs):
        """Transpose axes 1 and 2 of ``inputs[0]`` when it is rank 4.

        Args:
            inputs: Sequence holding a single tensor (``num_inputs == 1``).

        Returns:
            The transposed tensor for rank-4 inputs. For any other rank, the
            input tensor object itself is returned unchanged.
        """
        inp = inputs[0]
        if len(inp.shape) == 4:
            # perm [0, 2, 1, 3] keeps axes 0 and 3 and swaps axes 1 and 2.
            return tf.transpose(inp, perm=[0, 2, 1, 3])
        else:
            return inp

    def _compute_output_shape(self, input_shape):
        """Return the output shape matching ``call_native``.

        For rank-4 shapes, dims 1 and 2 are swapped and a new ``list`` is
        returned. Otherwise ``input_shape`` is returned as-is (same object and
        type).

        NOTE(review): the return type differs between the two branches (a
        list versus the caller's original type). This is kept for backward
        compatibility; confirm whether callers rely on it.
        """
        if len(input_shape) == 4:
            shape = [input_shape[0], input_shape[2], input_shape[1], input_shape[3]]
        else:
            shape = input_shape
        return shape

    def export_weights(self):
        """Return an empty dict: the op has no weights to export."""
        return dict()

    def create_hw_params(self, **kwargs):
        """No-op: the op has no hardware parameters to derive."""
        pass

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation (its HW-sim path is the
        native transpose, which does not alter values)."""
        return True
