import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.utils.hn_npz_utils import LayerParams


class NormalizationOp(BaseAtomicOp):
    """
    Emulate layer, instance and group normalization (optionally RMS-style) on 4D NHWC inputs.

    The input is reshaped to 5D ``(batch, height, width, groups, channels // groups)`` and the statistics
    are computed over ``reduce_axes`` (re-mapped into that 5D layout), which covers:
        1. Layer Normalization - ``groups=1``; normalizes each data sample across the reduced axes
           (with the default ``reduce_axes=(3,)``: across all channels at each spatial location).
        2. Instance Normalization - ``groups=channels`` with ``reduce_axes=(1, 2)``; normalizes across
           all spatial locations per channel.
        3. Group Normalization - divides the channels into ``groups`` groups and computes the mean and
           variance within each group (e.g. ``reduce_axes=(1, 2, 3)``).
        Note that if the number of groups is equal to the number of channels, group normalization
        becomes identical to instance normalization.

    Only the normalization itself is computed here; no scale/shift is applied to the result.

    args:
        rms_norm: if True, the inputs are multiplied by the inverse square root of their mean square over
            ``reduce_axes`` (plus epsilon). The mean of the input is not computed or subtracted.
        reduce_axes: axes of the 4D input over which the statistics are computed. Negative axes are
            allowed. Axis 3 (channels) refers to the channels *within* each group, so statistics are never
            mixed across groups.
        groups: the number of groups to divide the channels into. Must divide the number of channels.
            If groups=channels, this operation becomes identical to Instance Normalization.
        epsilon: a small float added to the variance to avoid dividing by zero. This is not a constructor
            argument: it is set by ``import_weights`` (default 1e-6) and is None until then.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, logger=None, rms_norm=False, reduce_axes=(3,), groups=1, **kwargs):
        super().__init__(name, logger=logger, fully_native=True, **kwargs)
        self._reduce_axes = reduce_axes
        self._rms_norm = rms_norm
        self._groups = groups
        self._epsilon = None  # populated by import_weights

    @property
    def rms_norm(self):
        return self._rms_norm

    @property
    def epsilon(self):
        return self._epsilon

    @property
    def groups(self):
        return self._groups

    @property
    def reduce_axes(self):
        return self._reduce_axes

    @staticmethod
    def _dim(tensor, axis):
        """Return the static size of ``axis`` when known, otherwise its dynamic size as a scalar tensor."""
        static_size = tensor.shape[axis]
        return static_size if static_size is not None else tf.shape(tensor)[axis]

    def call_native(self, inputs, **kwargs):
        """
        Normalize ``inputs[0]`` per group over ``reduce_axes`` and return a tensor of the same shape.

        Raises:
            ValueError: if the (statically known) number of channels is not divisible by ``groups``.
        """
        inputs_x = inputs[0]  # 4D (batch, height, width, channels)

        # Map 4D axes into the 5D grouped layout: 0->0, 1->1, 2->2, 3->4. The channel axis becomes the
        # in-group channel axis (4). `% 4` also turns negative axes into their positive counterparts.
        reduce_axes = [a % 4 + (a % 4) // 3 for a in self.reduce_axes]

        # Prefer static sizes; fall back to dynamic ones so inputs with unknown (None) dims still reshape.
        height = self._dim(inputs_x, 1)
        width = self._dim(inputs_x, 2)
        channels = self._dim(inputs_x, 3)
        if isinstance(channels, int) and channels % self.groups != 0:
            raise ValueError(
                f"NormalizationOp: number of channels ({channels}) must be divisible by groups ({self.groups})"
            )

        # Reshape to 5D (batch, height, width, groups, channels // groups).
        group_input = tf.reshape(inputs_x, [-1, height, width, self.groups, channels // self.groups])

        if self.rms_norm:
            # RMS norm: no mean subtraction, so `var` below is the mean square of the input.
            diff = group_input
        else:
            mu = tf.reduce_mean(input_tensor=group_input, axis=reduce_axes, keepdims=True)
            diff = group_input - mu  # mu broadcasts over the reduced (size-1) axes

        # Reduced axes are kept with size 1, e.g. (batch, 1, 1, groups, 1) for group normalization.
        var = tf.reduce_mean(input_tensor=tf.square(diff), axis=reduce_axes, keepdims=True)
        inv_sqrt = 1.0 / tf.sqrt(var + self.epsilon)
        # Broadcasts back to (batch, height, width, groups, channels // groups).
        res = tf.multiply(diff, inv_sqrt)

        # Reshape back to 4D (batch, height, width, channels).
        return tf.reshape(res, [-1, height, width, channels])

    def _compute_output_shape(self, input_shapes):
        """Normalization does not change the shape of its input."""
        return input_shapes

    def create_weight_quant_element(self, **kwargs):
        """No-op: this op creates no weight quantization elements."""
        pass

    def export_weights(self):
        """Export epsilon (None if ``import_weights`` has not been called yet)."""
        return {"epsilon": self.epsilon}

    def export_hw_params(self) -> dict:
        """This op has no hardware parameters."""
        return dict()

    def call_hw_sim(self, inputs, **kwargs):
        """Hardware simulation reuses the native computation unchanged."""
        return self.call_native(inputs, **kwargs)

    def import_weights(self, layer_params: LayerParams):
        """Load epsilon from ``layer_params``, defaulting to 1e-6 when it is absent."""
        self._epsilon = layer_params.get("epsilon", 1e-6)
