import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EW_MULT_MULTIPLIER_SHIFT,
    HAILO15_EW_MULT_MULTIPLIER_SHIFT,
    OptimizationTarget,
)


class SoftmaxMaskOp(BaseAtomicOp):
    """
    This class is used to mask the inputs of softmax before the reduce_max layer.
    In native, this op returns -inf for the masked inputs.
    In HW, this op returns -zp*scale for the masked inputs.

    Input 0 is the data and input 1 is the mask. Native emulation replaces every position
    where the mask equals 0 with -inf; HW simulation multiplies the two inputs element-wise
    and divides the product by 2**multiplier_shift.
    """

    num_inputs = 2
    num_outputs = 1

    def __init__(self, name: str, logger=None, fully_native=True, **kwargs):
        """Forward all arguments to the base op and start with a multiplier shift of 0."""
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        # Overwritten by _set_multiplier_shift (target specific constant) or by
        # import_independent_params.
        self._multiplier_shift = 0

    def enforce_encoding(self):
        """
        Infers the output zp and output scale based on the inputs scales and inputs zp

        output_scale      = input_scale[0] * input_scale[1] * 2**multiplier_shift
        output_zero_point = input_zero_point[0] / (mean(input_scale[1]) * 2**multiplier_shift)

        Raises:
            ValueError: if the zero point of input 1 is not zero.
        """
        # Validate first, so that a rejected configuration does not leave the op with a freshly
        # written output_scale next to a stale output_zero_point.
        if self.input_zero_points[1] != 0:
            raise ValueError(f"SoftmaxMask input 1 zp has to be zeros, received {self.input_zero_points[1]}")

        def mean_scale(scale):
            # The scale may be a TF tensor or a numpy/python value; reduce it with the matching library.
            if tf.is_tensor(scale):
                return tf.reduce_mean(scale)
            return np.mean(scale)

        shift_factor = 2**self._multiplier_shift
        post_mult_scale = tf.cast(self.input_scales[0], self.FLOAT_TYPE_TF) * self.input_scales[1]
        self.output_scale = post_mult_scale * shift_factor
        self.output_zero_point = self.input_zero_points[0] / (mean_scale(self.input_scales[1]) * shift_factor)

    def create_hw_params(self, optimization_target, **kwargs):
        """Select the multiplier shift for ``optimization_target`` and recompute the output encoding."""
        self._set_multiplier_shift(optimization_target=optimization_target)
        self.enforce_encoding()

    def call_hw_sim(self, inputs, **kwargs):
        """Return the element-wise product of the two inputs divided by 2**multiplier_shift."""
        return tf.multiply(*inputs) / (2**self._multiplier_shift)

    def call_native(self, inputs, **kwargs):
        """Return ``inputs[0]`` with -inf written wherever ``inputs[1]`` equals 0."""
        return tf.where(inputs[1] == 0, -np.inf, inputs[0])

    def _build(self, input_shape):
        """
        Validate that the two input shapes are compatible.

        Indices 1 and 2 of each shape (named height and width in the error messages) must be
        equal or have one side equal to 1; index 3 (the feature count) must be equal.

        Raises:
            ValueError: if any of the checks above fails.
        """
        input0 = input_shape[0]
        input1 = input_shape[1]
        for axis, axis_name in ((1, "height"), (2, "width")):
            if not (input0[axis] == 1 or input1[axis] == 1 or input0[axis] == input1[axis]):
                raise ValueError(
                    f"SoftmaxMask inputs {axis_name} must either be equal or 1 (for broadcast) {self.full_name}"
                )
        if input0[3] != input1[3]:
            raise ValueError(f"SoftmaxMask inputs must have same feature count {self.full_name}")

    def _compute_output_shape(self, input_shape):
        """Return the shape of input 0 as the output shape."""
        # NOTE(review): _build also accepts input 0 being the size-1 (broadcast) operand along
        # height/width; in that case the tensor produced by call_native would follow input 1's
        # size on that axis rather than the shape returned here - confirm this cannot happen.
        return input_shape[0]

    def _set_multiplier_shift(self, optimization_target):
        """Store the element-wise multiplier shift constant that matches ``optimization_target``."""
        if optimization_target in [OptimizationTarget.MERCURY, OptimizationTarget.PLUTO]:
            self._multiplier_shift = HAILO15_EW_MULT_MULTIPLIER_SHIFT
        else:
            self._multiplier_shift = EW_MULT_MULTIPLIER_SHIFT

    def export_independent_params(self):
        """Return the multiplier shift under the ``"mult_shift"`` key as a float32 numpy array."""
        return {
            # Read back by import_independent_params.
            "mult_shift": np.array(self._multiplier_shift, np.float32),
        }

    def export_hw_params(self):
        """Return the multiplier shift under the ``"ew_mult_apu_shift"`` key as a uint8 numpy array."""
        return {"ew_mult_apu_shift": np.array(self._multiplier_shift, np.uint8)}

    def import_independent_params(self, params):
        """Restore the multiplier shift from ``params["mult_shift"]`` (stored as given, without casting)."""
        self._multiplier_shift = params["mult_shift"]

    def create_weight_quant_element(self, *args):
        """Do nothing; this op does not create a weight quantization element."""
        pass

    def define_constraints(self, enc):
        """
        Register this op's scale / zero-point relations on ``enc``.

        The relations appear to mirror enforce_encoding; the exact semantics of the ``enc``
        methods are defined outside this file.
        """
        super().define_constraints(enc)

        # compute output_scale
        enc.mul(
            enc.dummy("post_mult_scale"),
            f"{self.full_name}/input_scale:0",
            f"{self.full_name}/input_scale:1",
            inverse=True,
        )
        enc.shift(f"{self.full_name}/output_scale:0", enc.dummy("post_mult_scale"), self._multiplier_shift)

        # compute output_zero_point
        enc.identity(f"{self.full_name}/input_zero_point:1", 0.0)
        # NOTE(review): enforce_encoding divides by the mean of input_scale[1], while this
        # constraint references input_scale:1 directly - confirm the two agree for vector scales.
        enc.div(enc.dummy("pre_shift_zp"), f"{self.full_name}/input_zero_point:0", f"{self.full_name}/input_scale:1")
        enc.shift(enc.dummy("pre_shift_zp"), f"{self.full_name}/output_zero_point:0", self._multiplier_shift)

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True
