import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.softmax_mask_op import SoftmaxMaskOp
from hailo_model_optimization.acceleras.utils.opt_utils import bankers_round_int_shift


class SoftmaxMaskOnMacOp(SoftmaxMaskOp):
    """
    Softmax-mask op implemented as an elementwise multiplication on the MAC.

    The two inputs are multiplied elementwise on the MAC (Multiply-Accumulate) units and the
    product is right-shifted by ``_multiplier_shift``. It takes two inputs and produces one output.

    Zero points: input 0 may have a non-zero zero point, which is folded into the output zero
    point (see ``enforce_encoding``); input 1 must have a zero point of 0. The multiplier shift
    multiplies the output scale by ``2**_multiplier_shift``.

    Attributes:
        num_inputs (int): The number of input tensors.
        num_outputs (int): The number of output tensors.
    """

    num_inputs = 2
    num_outputs = 1

    def enforce_encoding(self):
        """
        Infer the output scale and output zero point from the input scales and zero points.

        The output scale is ``input_scale_0 * input_scale_1 * 2**_multiplier_shift``. Input 0's
        zero point is carried to the output as
        ``input_zero_point_0 / (mean(input_scale_1) * 2**_multiplier_shift)``.

        Raises:
            ValueError: If any element of input 1's zero point is non-zero. This is checked
                before any encoding attribute is modified.
        """
        # Validate before mutating, so a rejected encoding leaves output_scale untouched.
        # np.any accepts both a scalar zero point and a per-channel array of zero points.
        if np.any(np.asarray(self.input_zero_points[1]) != 0):
            raise ValueError(f"SoftmaxMask input 1 zp has to be zeros, received {self.input_zero_points[1]}")
        shift_factor = 2**self._multiplier_shift
        post_mult_scale = self.input_scales[0] * self.input_scales[1]
        self.output_scale = post_mult_scale * shift_factor
        # NOTE(review): uses the mean of input 1's scale, so this is only exact when that scale is
        # uniform across channels — TODO confirm the approximation is intended.
        self.output_zero_point = self.input_zero_points[0] / (np.mean(self.input_scales[1]) * shift_factor)

    def create_hw_params(self, force_shift, *args, **kwargs):
        """
        Set the multiplier shift and re-derive the output encoding from it.

        Args:
            force_shift: Shift applied to the MAC product; stored as ``_multiplier_shift``.
            *args: Accepted for interface compatibility; unused here.
            **kwargs: Accepted for interface compatibility; unused here.

        Raises:
            ValueError: Propagated from ``enforce_encoding`` when input 1's zero point is non-zero.
        """
        # The force shift value should be calculated based on reduce sum if exists & only when 8 bit...
        self._multiplier_shift = force_shift
        self.enforce_encoding()

    def call_bit_exact(self, inputs, **kwargs):
        """
        Bit-exact path: elementwise product, then shifted via ``bankers_round_int_shift``.

        Args:
            inputs: The two input tensors, multiplied elementwise.

        Returns:
            The product, shifted by ``_multiplier_shift`` only when the shift is positive.
        """
        post_mult = tf.multiply(*inputs)
        if self._multiplier_shift > 0:
            post_mult = bankers_round_int_shift(post_mult, self._multiplier_shift)
        return post_mult

    def call_hw_sim(self, inputs, **kwargs):
        """
        Hardware-simulation path: elementwise product divided by ``2**_multiplier_shift``.

        Unlike ``call_bit_exact``, no rounding is applied to the division result.
        """
        post_mult = tf.multiply(*inputs)
        return post_mult / (2**self._multiplier_shift)

    def export_hw_params(self):
        """
        Export the hardware parameters of this op.

        Both input zero points are exported as int32 zeros shaped like the stored zero points.
        Presumably this is because input 1's zero point is required to be 0 and input 0's zero
        point is folded into the output zero point in ``enforce_encoding`` — confirm.

        Returns:
            dict: ``output_stage/mult_shift`` (uint8) and ``zero_point_in_0`` / ``zero_point_in_1``.
        """
        return {
            "output_stage/mult_shift": np.array(self._multiplier_shift, np.uint8),
            "zero_point_in_0": np.zeros_like(self.input_zero_points[0], dtype=np.int32),
            "zero_point_in_1": np.zeros_like(self.input_zero_points[1], dtype=np.int32),
        }

    def define_constraints(self, enc):
        """
        Defines the encoding constraints for the element-wise multiplication operation.

        Relates the output scale to the product of the two input scales shifted by
        ``_multiplier_shift``, and pins the output zero point to 0.0.

        Args:
            enc: An instance of the encoder used for constraint definition.
        """
        # super(SoftmaxMaskOp, self) skips SoftmaxMaskOp.define_constraints and calls the next
        # class after it in the MRO. NOTE(review): presumably intentional, replacing the parent's
        # constraints — confirm.
        super(SoftmaxMaskOp, self).define_constraints(enc)

        # compute output_scale
        enc.mul(
            enc.dummy("post_mult_scale"),
            f"{self.full_name}/input_scale:0",
            f"{self.full_name}/input_scale:1",
            inverse=True,
        )
        enc.shift(f"{self.full_name}/output_scale:0", enc.dummy("post_mult_scale"), self._multiplier_shift)
        # NOTE(review): this pins the output zero point to 0, whereas enforce_encoding derives it
        # from input 0's zero point — confirm the two are meant to differ.
        enc.identity(f"{self.full_name}/output_zero_point:0", 0.0)
