import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.bias_add_op import AddBiasOp


class AddBias3DOp(AddBiasOp):
    """
    Bias-add atomic op for 3D ops, where each output channel's bias is shared
    across the output disparity dimension.

    Args:
        output_disparity: the output disparity D. The channel dimension of the
            input is expected to be Cout x D.

    Attributes:
        short_bias (Cout): bias vector of length output_channels.

    Notes
    * input_shape for this layer is of shape [B, H, W, (Cout x D)], where D is
      ``output_disparity``.
    * AddBias3DOp imports and exports only the short_bias.
    * The short_bias is the actual bias vector for this layer and has length
      Cout.
    * When the bias is added to the output of the previous op (usually a
      convolution op), the short_bias is expanded by tiling to length
      Cout x D. The whole vector is repeated D times back-to-back, so
      channel index ``d * Cout + c`` receives ``short_bias[c]``.

    """

    short_bias: tf.Tensor

    def __init__(
        self,
        name,
        output_disparity,
        bias_initializer=None,
        axis=(-1,),
        trainable=True,
        is_correctable=True,
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        super().__init__(
            name,
            bias_initializer=bias_initializer,
            axis=axis,
            trainable=trainable,
            is_correctable=is_correctable,
            logger=logger,
            fully_native=fully_native,
            **kwargs,
        )
        self._output_disparity = output_disparity
        # The full-length bias is derived from `short_bias` by the `bias`
        # property, so nothing is kept in `_bias` here.
        # NOTE(review): presumably this resets state the parent may set; confirm.
        self._bias = None

    def _short_bias_length(self, channels):
        """Return ``channels // output_disparity``, rejecting non-multiples.

        Raises:
            ValueError: if ``channels`` is not a multiple of output_disparity.
                Flooring silently would make the tiled ``bias`` shorter than
                the channel dimension it is added to.
        """
        if channels % self._output_disparity != 0:
            raise ValueError(
                f"{type(self).__name__}: input dimension {channels} is not divisible "
                f"by output_disparity ({self._output_disparity})"
            )
        return channels // self._output_disparity

    def _build(self, input_shape):
        """Create the ``short_bias`` weight.

        Each dimension selected by ``self.axis`` is divided by output_disparity.
        For a rank-1 ``input_shape`` the weight shape is a plain int (the first
        selected axis only); otherwise it is a list with one entry per axis.

        Raises:
            ValueError: if a selected dimension is not a multiple of
                output_disparity.
        """
        if len(input_shape) == 1:
            shape = self._short_bias_length(input_shape[self.axis[0]])
        else:
            shape = [self._short_bias_length(input_shape[ax]) for ax in self.axis]
        self.short_bias = self.add_weight(
            shape=shape,
            trainable=self.trainable,
            initializer=self.bias_initializer,
            name="bias",
        )

    @property
    def bias(self):
        """
        A tiling of the short_bias (Cout) to length Cout x output_disparity.

        Returns
            bias (Cout x D): a flat vector made of short_bias repeated D times
            back-to-back (``[b_0..b_{Cout-1}, b_0..b_{Cout-1}, ...]``).

        """
        # [Cout] -> [1, Cout] -> tile along columns -> [1, Cout * D] -> flat.
        return tf.reshape(tf.tile(tf.reshape(self.short_bias, [1, -1]), [1, self._output_disparity]), [-1])

    def import_weights(self, short_bias, **kwargs):
        """Load a pretrained short bias of length Cout.

        If the op is built, the existing weight variable is updated in place.
        Otherwise the value is stored as a constant, and the initializer is
        replaced so that a later build starts from these values. Extra keyword
        arguments are accepted and ignored.
        """
        self.pretrained_short_bias = short_bias
        if self.built:
            self.short_bias.assign(short_bias)
        else:
            self.short_bias = tf.constant(short_bias)
        self.bias_initializer = tf.keras.initializers.Constant(short_bias)

    def export_weights(self):
        """Return the short bias (length Cout) as a numpy array."""
        return self.short_bias.numpy()

    def export_hw_params(self):
        """Export hardware params, truncating the bias to its short (Cout) form.

        This takes the parent's exported ``"bias"`` and keeps only its first
        ``short_bias.shape[0]`` entries. If ``"bias_q_int8_vec_a"`` is present,
        it is overwritten with that same truncated vector.

        NOTE(review): this assumes the parent's encoded bias follows the same
        tiled layout as ``bias``, so its first Cout entries are the encoded
        short bias. Confirm against AddBiasOp.export_hw_params.
        """
        params = super().export_hw_params()
        encoded_short_bias = params["bias"][: self.short_bias.shape[0]]
        params["bias"] = encoded_short_bias
        if "bias_q_int8_vec_a" in params:
            params["bias_q_int8_vec_a"] = encoded_short_bias
        return params
