import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp


class ReduceMaxOp(BaseAtomicOp):
    """
    This class emulates the reduce max operation.

    The maximum is taken over ``reduce_axes`` with ``keepdims=True``. When ``groups > 1`` the
    channel axis (axis 3 of the rank-4 input, see ``call_native``) is split into ``groups``
    equal slices. Reducing that axis then yields one value per group instead of a single value.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, groups=1, reduce_axes=None, logger=None, fully_native=None, **kwargs):
        """
        Args:
            name: Name of the op, forwarded to ``BaseAtomicOp``.
            groups: Number of equal channel groups that are reduced independently.
            reduce_axes: A single axis or a sequence of axes to reduce. ``None`` or an empty
                sequence means ``[-1]`` (the last axis).
            logger: Forwarded to ``BaseAtomicOp``.
            fully_native: Forwarded to ``BaseAtomicOp``.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._groups = groups
        # Normalize once, so every method can rely on a plain, non-empty list of ints.
        self._reduce_axes = self._as_axes_list(reduce_axes)

    @staticmethod
    def _as_axes_list(reduce_axes):
        """Normalize ``reduce_axes`` (None, a scalar axis, a list or a tuple) to a non-empty list of ints."""
        if reduce_axes is None:
            return [-1]
        try:
            iterator = iter(reduce_axes)
        except TypeError:
            # A single scalar axis such as ``3`` (including ``0``, which is a valid axis).
            return [int(reduce_axes)]
        return [int(axis) for axis in iterator] or [-1]

    def _reduces_spatial_axes_only(self):
        """
        Return True when only axes 1 and/or 2 are reduced, i.e. axis 3 (the axis that
        ``call_native`` splits into groups) is left untouched.

        Negative axes are normalized assuming a rank-4 input, the same ``a % 4`` assumption the
        grouped path of ``call_native`` makes. Order and duplicates do not matter.
        """
        normalized = {axis % 4 for axis in self._reduce_axes}
        return bool(normalized) and normalized <= {1, 2}

    def create_hw_params(self, **kwargs):
        # Nothing to prepare: this op has no hardware parameters.
        pass

    def enforce_encoding(self):
        """
        Infers the output zp and output scale based on the inputs scales and inputs zp.

        If only spatial axes are reduced, each output channel keeps its input channel's scale and
        zero point, because the max over H/W never mixes channels. Otherwise non-scalar values are
        averaged within each of the ``groups`` channel groups. This matches the ``groups`` output
        channels produced when the channel axis is reduced.
        """
        spatial_only = self._reduces_spatial_axes_only()

        if self.input_scale_is_scalar():
            self.output_scale = self.input_scale
        elif spatial_only:
            self.output_scale = self.input_scales[0]
        else:
            self.output_scale = tf.reduce_mean(tf.reshape(self.input_scale, (self._groups, -1)), axis=1)

        input_zero_point = self.input_zero_points[0]
        if spatial_only or len(tf.convert_to_tensor(input_zero_point).shape) == 0:
            self.output_zero_point = input_zero_point
        else:
            self.output_zero_point = tf.reduce_mean(tf.reshape(input_zero_point, (self._groups, -1)), axis=1)

    def call_bit_exact(self, inputs, **kwargs):
        # The max of integer values is exact, so bit-exact emulation reuses the native path.
        return self.call_native(inputs, **kwargs)

    def call_hw_sim(self, inputs, **kwargs):
        # Hardware simulation reuses the native path as well.
        return self.call_native(inputs, **kwargs)

    def call_native(self, inputs, **kwargs):
        """
        Compute the reduce-max of ``inputs[0]`` over ``self._reduce_axes`` with kept dims.

        With ``groups > 1`` the input is assumed to be rank 4, with axis 3 divisible by
        ``groups``. It is viewed as ``[B, H, W, groups, C // groups]`` so that reducing the
        channel axis produces one value per group. The result is reshaped back to rank 4.
        """
        inp = inputs[0]
        if self._groups > 1:
            # Map rank-4 axes into the rank-5 grouped view: 0, 1, 2 stay the same, while
            # 3 / -1 become 4 (the within-group channel axis). The result is in {0, 1, 2, 4}.
            reduce_axes = [a % 4 + (a % 4) // 3 for a in self._reduce_axes]
            group_input = tf.reshape(inp, [-1, inp.shape[1], inp.shape[2], self._groups, inp.shape[3] // self._groups])
            reduce_max = tf.reduce_max(
                input_tensor=group_input,
                axis=reduce_axes,
                keepdims=True,
                name="reduce_max_group",
            )
            # Merge the group axis with the (possibly reduced) within-group axis.
            reduce_max = tf.reshape(
                reduce_max,
                [-1, reduce_max.shape[1], reduce_max.shape[2], reduce_max.shape[3] * reduce_max.shape[4]],
            )
        else:
            reduce_max = tf.reduce_max(input_tensor=inp, axis=self._reduce_axes, keepdims=True, name="reduce_max")
            output_shape = [dim if i not in self._reduce_axes else 1 for i, dim in enumerate(reduce_max.shape)]
            # Re-apply the static non-batch dims, leaving the batch dim dynamic (-1).
            reduce_max = tf.reshape(reduce_max, [-1, *output_shape[1:]])
        return reduce_max

    def _compute_output_shape(self, input_shapes):
        """
        Return the output shape for the (single) input shape ``input_shapes``.

        Reduced axes keep size 1 (``keepdims=True``). The exception is the last (channel) axis,
        which becomes ``groups`` because ``call_native`` keeps one maximum per channel group.
        Non-reduced axes are copied unchanged.
        """
        reduced = set(self._as_axes_list(self._reduce_axes))
        rank = len(input_shapes)
        shape = []
        for axis, dim in enumerate(input_shapes):
            if axis not in reduced and axis - rank not in reduced:
                shape.append(dim)
            elif axis == rank - 1:
                shape.append(self._groups)
            else:
                # Spatial/other reduced axes collapse to 1 even when groups > 1.
                shape.append(1)
        return shape

    def create_weight_quant_element(self, **kwargs):
        # This op has no weights to quantize.
        pass

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True
