import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp


class ArgMaxOp(BaseAtomicOp):
    """
    Atomic op that emulates an argmax taken over the last axis of its single input.

    The reduced axis is kept with size 1 and the indices are returned as float32.
    When ``reverse_order`` is set, the input is flipped along the last axis before
    the argmax is taken, so the returned index refers to the flipped ordering.
    """

    num_inputs = 1
    num_outputs = 1

    def __init__(self, name: str, reverse_order=False, logger=None, fully_native=None, **kwargs):
        """
        Args:
            name: name of the op, forwarded to the base class.
            reverse_order: if truthy, flip the input along its last axis before argmax.
            logger: forwarded unchanged to the base-class constructor.
            fully_native: forwarded unchanged to the base-class constructor.
            **kwargs: forwarded unchanged to the base-class constructor.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self._reverse_order = reverse_order

    def enforce_encoding(self):
        """
        Pins the output encoding to constants: scale ``[1.0]`` and zero point ``0``.

        Nothing is derived from the input encoding here.
        NOTE(review): presumably because the output is an index, not a quantized
        value -- confirm against the rest of the framework.
        """
        self.output_scale = np.array([1.0], dtype=np.float32)
        self.output_zero_point = np.array(0, dtype=np.float32)

    def is_differentiable(self) -> bool:
        """This op reports itself as non-differentiable."""
        return False

    def call_native(self, inputs, **kwargs):
        """
        Computes the argmax of ``inputs[0]`` along its last axis.

        Args:
            inputs: sequence whose first element is the tensor to reduce.
            **kwargs: accepted but not used.

        Returns:
            A float32 tensor of indices in which the last axis has size 1.
        """
        last_axis = -1
        scores = inputs[0]
        if self._reverse_order:
            # Indices produced below are relative to this flipped tensor.
            scores = tf.reverse(scores, axis=[last_axis])
        indices = tf.math.argmax(input=scores, axis=last_axis, name="argmax", output_type=tf.int32)
        # Re-insert the reduced axis first, then convert the int32 indices to float32.
        indices = tf.expand_dims(indices, axis=last_axis)
        return tf.cast(indices, tf.float32)

    def call_hw_sim(self, inputs, **kwargs):
        """Hardware simulation reuses the native implementation; ``kwargs`` are not forwarded."""
        return self.call_native(inputs)

    def create_weight_quant_element(self, **kwargs):
        """No-op: nothing is created here for this op."""

    def create_hw_params(self, **kwargs):
        """No-op: nothing is created here for this op."""

    def export_weights(self):
        """Returns an empty dict; this op exports no weights."""
        return {}

    def define_constraints(self, enc):
        """
        Registers the base-class constraints, then fixes the output encoding on ``enc``:
        output scale to ``[1.0]`` and output zero point to ``0.0``.
        """
        super().define_constraints(enc)
        enc.identity(f"{self.full_name}/output_scale:0", [1.0])
        enc.identity(f"{self.full_name}/output_zero_point:0", 0.0)

    @property
    def bit_exact_supported(self) -> bool:
        """Whether bit-exact emulation is supported; always ``True`` for this op."""
        return True
