import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp


class SoftmaxInpOp(BaseAtomicOp):
    """
    Mask the inputs of a softmax before its reduce_max layer.

    With ``num_inputs == 1`` the op is a pass-through. With ``num_inputs == 2`` the
    second input acts as a mask: wherever it equals 0, the native and HW-simulation
    paths return -inf instead of the first (data) input's value. The bit-exact path
    returns its inputs unchanged.
    """

    num_outputs = 1

    def __init__(
        self,
        name,
        num_inputs=1,
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        # Set before super().__init__(), presumably so the base class can read
        # ``num_inputs`` during initialization.
        self._num_inputs = num_inputs
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)

    @property
    def num_inputs(self) -> int:
        """Number of inputs this op was configured with (1 = pass-through, 2 = data + mask)."""
        return self._num_inputs

    def _compute_output_shape(self, input_shapes):
        """Return the single input's shape, or the data (first) input's shape when masking."""
        # NOTE(review): _build accepts a data input with height/width 1 and a larger mask;
        # tf.where would then broadcast to the mask's size while this reports the data
        # shape. Confirm that combination is never expected.
        if self.num_inputs == 1:
            return input_shapes
        return input_shapes[0]

    def call_native(self, inputs, **kwargs):
        """
        Apply the mask in floating point.

        Args:
            inputs: the single input when ``num_inputs == 1``; otherwise a pair
                ``(data, mask)``.

        Returns:
            The input unchanged for one input; otherwise ``data`` with -inf wherever
            ``mask == 0`` (operands are broadcast by ``tf.where``).
        """
        if self.num_inputs == 1:
            return inputs
        data = inputs[0]
        mask = inputs[1]
        # Build -inf in the data's dtype: a bare Python float is converted to float32
        # and can clash with float16/float64 data inside tf.where.
        neg_inf = tf.constant(-np.inf, dtype=data.dtype)
        return tf.where(mask == 0, neg_inf, data)

    def _build(self, input_shape):
        """
        Validate that the data and mask shapes are compatible (two-input case only).

        Height (axis 1) and width (axis 2) must either be equal or be 1 on one side so
        they can broadcast; the feature count (axis 3) must match exactly.

        Raises:
            ValueError: if the two input shapes are incompatible.
        """
        if self.num_inputs != 2:
            return
        input0 = input_shape[0]
        input1 = input_shape[1]
        for axis, dim_name in ((1, "height"), (2, "width")):
            dim0 = input0[axis]
            dim1 = input1[axis]
            if not (dim0 == 1 or dim1 == 1 or dim0 == dim1):
                raise ValueError(
                    f"SoftmaxMask inputs {dim_name} must either be equal or 1 (for broadcast) {self.full_name}"
                )
        if input0[3] != input1[3]:
            raise ValueError(f"SoftmaxMask inputs must have same feature count {self.full_name}")

    def create_weight_quant_element(self, **kwargs):
        """No-op: this op exports no weights (see ``export_weights``)."""
        pass

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True

    def call_bit_exact(self, inputs, **kwargs):
        """Return ``inputs`` unchanged; no -inf masking is applied in this path."""
        # NOTE(review): with two inputs this returns both the data and the mask although
        # num_outputs == 1. Confirm the bit-exact consumer handles the mask itself.
        return inputs

    def call_hw_sim(self, inputs, **kwargs):
        """HW simulation is identical to the native computation (kwargs are not forwarded)."""
        return self.call_native(inputs)

    def export_weights(self):
        """This op has no weights to export."""
        return dict()

    def create_hw_params(self, *args, **kwargs):
        """No-op: there are no HW parameters to create for this op."""
        pass

    def enforce_encoding(self, *args, **kwargs):
        """Enforce the encoding by forwarding the data input's scale/zero point to the output."""
        self.forward_encoding()

    def forward_encoding(self):
        """
        Propagate the data (first) input's scale and zero point to the output.

        Raises:
            ValueError: if a mask input is present and its zero point is not 0.
        """
        # Validate before mutating so a failure leaves the output encoding untouched.
        if self.num_inputs == 2 and self.input_zero_points[1] != 0:
            raise ValueError(f"SoftmaxMask input 1 zp has to be zeros, received {self.input_zero_points[1]}")
        self.output_scale = self.input_scales[0]
        self.output_zero_point = self.input_zero_points[0]

    def define_constraints(self, enc):
        """Register identity constraints tying the output scale/zero point to the input's."""
        super().define_constraints(enc)
        enc.identity(f"{self.full_name}/output_scale:0", f"{self.full_name}/input_scale:0")
        enc.identity(f"{self.full_name}/output_zero_point:0", f"{self.full_name}/input_zero_point:0")
