import tensorflow as tf
from tensorflow.keras.layers import InputSpec

from hailo_model_optimization.acceleras.atomic_ops.base_non_arithmetic_op import BaseNonArithmeticAtomicOp


class ReduceSumWOKernelOp(BaseNonArithmeticAtomicOp):
    """
    Emulates a grouped reduce-sum over the channel axis.

    This reduce sum was designed for a "fused" operation after ew_mult on the MAC unit.
    It inherits from BaseNonArithmeticAtomicOp (despite being an arithmetic operation) because it
    doesn't have any weights or complex logic.
    In case this layer will have more complex logic, the inheritance should be reconsidered.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, groups=1, logger=None, fully_native=None, **kwargs):
        """
        Args:
            name (str): Name of the operation, forwarded to the base class.
            groups (int): Number of groups the reduced (channel) axis is split into; the output
                has one channel per group.
            logger: Forwarded to the base class.
            fully_native: Forwarded to the base class.
            **kwargs: Forwarded to the base class.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._groups = groups  # groups is the number of groups in the last channel
        self.input_spec = InputSpec(ndim=4)  # TODO should be list??
        self._reduce_axis = 3

    def get_config(self):
        """
        Returns the configuration of the operation as a dictionary.

        The base-class config is extended with the "groups", "input_spec" and "_reduce_axis" entries.

        Returns:
            dict: Configuration of the operation.
        """
        config = super().get_config()
        config.update({"groups": self._groups, "input_spec": self.input_spec, "_reduce_axis": self._reduce_axis})
        return config

    @classmethod
    def from_config(cls, config):
        """
        Creates an instance of the operation from the given configuration.

        Only "name" and "groups" are passed to the constructor. Every remaining entry is assigned
        onto the new instance only if an attribute with that exact key already exists in the
        instance ``__dict__``; other entries are ignored.

        Args:
            config (dict): Configuration of the operation. It is not modified.

        Returns:
            ReduceSumWOKernelOp: An instance of the operation.

        Raises:
            KeyError: If ``config`` has no "name" entry.
        """
        # Work on a shallow copy so popping the constructor arguments does not mutate the caller's dict.
        remaining = dict(config)
        valid_kwargs = {
            "name": remaining.pop("name"),
            "groups": remaining.pop("groups", 1),
        }
        instance = cls(**valid_kwargs)

        for key, value in remaining.items():
            if key in instance.__dict__:
                setattr(instance, key, value)

        return instance

    def enforce_encoding(self):
        """
        Infers the output zp and output scale based on the inputs scales and inputs zp
        """
        reduce_size = tf.math.reduce_prod(self.input_shape[self._reduce_axis])
        # Number of input channels summed into each output channel.
        group_size = tf.cast(reduce_size / self._groups, self.FLOAT_TYPE_TF)

        if self.input_scale_is_scalar():
            # A single input scale is replicated once per group.
            self.output_scale = tf.repeat(self.input_scale, self._groups)
        else:
            # Per-channel scales: each group's output scale is the mean of the scales in that group.
            self.output_scale = tf.reduce_mean(tf.reshape(self.input_scale, (self._groups, -1)), axis=1)

        # Summing group_size values multiplies the zero point by group_size.
        if len(tf.convert_to_tensor(self.input_zero_point).shape) == 0:
            self.output_zero_point = self.input_zero_points[0] * group_size
        else:
            # Only the first zero point of each group is used.
            # NOTE(review): presumably the zero point is uniform within a group - confirm.
            groups_zp = self.input_zero_point.reshape((self._groups, -1))[:, 0]
            self.output_zero_point = groups_zp * group_size

    def _compute_output_shape(self, input_shapes):
        """
        Returns the input shape with the reduced axis replaced by the number of groups.

        The reduce axis is matched both as a non-negative index and as a negative
        (counted from the end) index.

        Args:
            input_shapes: Sequence of dimensions of the input.

        Returns:
            list: Dimensions of the output.
        """
        num_axes = len(input_shapes)
        return [
            self._groups if self._reduce_axis in (axis, axis - num_axes) else val
            for axis, val in enumerate(input_shapes)
        ]

    def call_native(self, inputs, **kwargs):
        """
        Sums the input channels within each group.

        The first input is reshaped from 4 dimensions to 5 by splitting axis 3 into
        (groups, channels_per_group), and the channels_per_group axis is then summed, so the
        result has ``groups`` entries on its last axis. No factor or kernel is applied here.

        Args:
            inputs: Sequence whose first element is the input tensor (indexed on axes 1-3, so it
                is expected to have 4 dimensions).
            **kwargs: Unused.

        Returns:
            The reduced tensor.

        Raises:
            ValueError: If the statically known size of axis 3 is not divisible by the number of groups.
        """
        inp = inputs[0]
        channels = inp.shape[3]
        # Floor division would otherwise hide a bad groups value: since the leading dimension is
        # inferred (-1), the reshape could succeed and mix data across the batch.
        if channels is not None and channels % self._groups != 0:
            raise ValueError(
                f"Number of input channels ({channels}) must be divisible by groups ({self._groups})"
            )
        group_input = tf.reshape(inp, [-1, inp.shape[1], inp.shape[2], self._groups, channels // self._groups])
        # After the reshape the per-group channels sit one axis after the original reduce axis.
        reduce_sum = tf.reduce_sum(input_tensor=group_input, axis=self._reduce_axis + 1, name="reduce_sum_group")
        return reduce_sum

    def call_bit_exact(self, inputs, **kwargs):
        """Bit-exact emulation; delegates to the native implementation."""
        return self.call_native(inputs, **kwargs)

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True
