import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.element_wise_mult_op import ElementwiseMultOp
from hailo_model_optimization.acceleras.utils.opt_utils import bankers_round_int_shift


class ElementwiseMultOnMacOp(ElementwiseMultOp):
    """
    Elementwise multiplication on MAC operation.

    This operation performs elementwise multiplication on MAC (Multiply-Accumulate) units.
    It takes two inputs and produces one output.
    This op supports input zero points, and its multiplier shift affects the MAC output:
    the output scale is the product of the input scales times ``2 ** multiplier_shift``
    (see ``enforce_encoding``).

    Attributes:
        num_inputs (int): The number of input tensors.
        num_outputs (int): The number of output tensors.
    """

    num_inputs = 2
    num_outputs = 1

    def get_config(self):
        """
        Returns the configuration of the operation as a dictionary.

        Extends the parent configuration with ``multiplier_shift`` and ``input_repeats``.

        Returns:
            dict: Configuration of the operation.
        """
        config = super().get_config()
        config.update({"multiplier_shift": self._multiplier_shift, "input_repeats": self.input_repeats})
        return config

    @classmethod
    def from_config(cls, config):
        """
        Creates an instance of the operation from the given configuration.

        The instance is constructed from ``config["name"]``. ``multiplier_shift`` is restored
        into ``_multiplier_shift``. Every other key that matches an existing instance attribute
        is then assigned onto the instance.

        Args:
            config (dict): Configuration of the operation, e.g. as produced by ``get_config``.
                The caller's dict is not modified.

        Returns:
            ElementwiseMultOnMacOp: An instance of the operation.

        Raises:
            KeyError: If ``config`` has no ``"name"`` entry.
        """
        # Work on a shallow copy so popping "name" does not mutate the caller's dict.
        config = dict(config)
        instance = cls(name=config.pop("name"))

        # get_config() publishes the shift under the public key "multiplier_shift", while this op
        # stores it as `_multiplier_shift`. The generic `__dict__` match below compares against the
        # public key name, so it would not restore the private attribute by itself.
        if "multiplier_shift" in config:
            instance._multiplier_shift = config["multiplier_shift"]

        for key, value in config.items():
            if key in instance.__dict__:
                setattr(instance, key, value)

        return instance

    def enforce_encoding(self):
        """
        Infers the output zero point and output scale from the input scales and zero points.

        With ``s = 2 ** multiplier_shift``:

        * ``output_scale = input_scale_a * input_scale_b * s``
        * ``output_zero_point = -(zp_a * zp_b) / s`` (cast to ``FLOAT_TYPE_TF``)

        The zero-point term accounts for ``call_bit_exact`` computing ``a*b - a*zp_b - b*zp_a``,
        which equals ``(a - zp_a) * (b - zp_b) - zp_a * zp_b``. The leftover constant
        ``-zp_a * zp_b`` is therefore carried as the output zero point.
        """
        input_scale_a = self.input_scale_matrix[0]
        input_scale_b = self.input_scale_matrix[1]

        post_mult_scale = input_scale_a * input_scale_b
        post_shift_scale = post_mult_scale * tf.cast((2**self._multiplier_shift), post_mult_scale.dtype)
        self.output_scale = post_shift_scale
        self.output_zero_point = tf.cast(
            -tf.multiply(*self.input_zero_points)
            / tf.cast((2**self._multiplier_shift), self.input_zero_points[0].dtype),
            self.FLOAT_TYPE_TF,
        )

    def create_hw_params(self, force_shift, *args, **kwargs):
        """
        Sets the multiplier shift and re-derives the output encoding.

        Args:
            force_shift: Shift value to use. It is converted to a tensor and stored as
                ``_multiplier_shift``.
            *args: Ignored; accepted for interface compatibility.
            **kwargs: Ignored; accepted for interface compatibility.
        """
        # The force shift value should be calculated based on reduce sum if exists & only when 8 bit...
        self._multiplier_shift = tf.convert_to_tensor(force_shift)
        self.enforce_encoding()

    def call_bit_exact(self, inputs, **kwargs):
        """
        Integer (bit-exact) emulation of the MAC multiplication.

        Computes ``a*b - a*zp_b - b*zp_a``. When ``multiplier_shift > 0``, the result is then
        passed through ``bankers_round_int_shift`` with that shift. The hw-sim path divides by
        ``2 ** multiplier_shift`` at the same point.

        Args:
            inputs: Sequence holding the two input tensors ``(a, b)``.

        Returns:
            The zero-point-corrected (and optionally shifted) product tensor.
        """
        # NOTE(review): the return value of repeat_inputs is discarded, which assumes it updates
        # `inputs` in place. Confirm against ElementwiseMultOp.
        self.repeat_inputs(inputs)
        # Cast each zero point to the dtype of the operand it actually multiplies:
        # zp_b pairs with a, and zp_a pairs with b.
        zp_a = tf.cast(self.input_zero_points[0], inputs[1].dtype)
        zp_b = tf.cast(self.input_zero_points[1], inputs[0].dtype)
        post_mult = (
            tf.multiply(inputs[0], inputs[1])
            - tf.multiply(inputs[0], zp_b)
            - tf.multiply(inputs[1], zp_a)
        )
        if self._multiplier_shift > 0:
            post_mult = bankers_round_int_shift(post_mult, self._multiplier_shift)
        return post_mult

    def call_hw_sim(self, inputs, **kwargs):
        """
        Floating-point simulation of the MAC multiplication.

        Computes ``(a*b - a*zp_b - b*zp_a) / 2 ** multiplier_shift`` without rounding.

        Args:
            inputs: Sequence holding the two input tensors ``(a, b)``.

        Returns:
            The zero-point-corrected product divided by ``2 ** multiplier_shift``.
        """
        # NOTE(review): assumes repeat_inputs updates `inputs` in place (return value unused).
        self.repeat_inputs(inputs)
        # Same operand/zero-point pairing and dtype handling as call_bit_exact.
        zp_a = tf.cast(self.input_zero_points[0], inputs[1].dtype)
        zp_b = tf.cast(self.input_zero_points[1], inputs[0].dtype)
        post_mult = (
            tf.multiply(inputs[0], inputs[1])
            - tf.multiply(inputs[0], zp_b)
            - tf.multiply(inputs[1], zp_a)
        )
        return post_mult / tf.cast(2**self._multiplier_shift, post_mult.dtype)

    def export_hw_params(self):
        """
        Returns the hardware parameters of this op.

        Returns:
            dict: ``output_stage/mult_shift`` as ``uint8``, and ``zero_point_in_0`` /
            ``zero_point_in_1`` as ``int32``.
        """
        return {
            "output_stage/mult_shift": np.array(self._multiplier_shift, np.uint8),
            "zero_point_in_0": np.array(self.input_zero_points[0], np.int32),
            "zero_point_in_1": np.array(self.input_zero_points[1], np.int32),
        }

    def define_constraints(self, enc):
        """
        Defines the constraints for the element-wise multiplication operation.

        Adds these constraints:

        * an ``enc.mul`` relating a dummy ``post_mult_scale`` to the two input scales
          (``inverse=True``);
        * an ``enc.shift`` of that dummy by ``multiplier_shift`` into the output scale;
        * the output zero point pinned to ``0.0``.

        Args:
            enc: An instance of the encoder used for constraint definition.
        """
        # Deliberately skips ElementwiseMultOp.define_constraints and calls the next class in the
        # MRO, presumably because this op derives its output encoding differently.
        super(ElementwiseMultOp, self).define_constraints(enc)

        # compute output_scale
        enc.mul(
            enc.dummy("post_mult_scale"),
            f"{self.full_name}/input_scale:0",
            f"{self.full_name}/input_scale:1",
            inverse=True,
        )
        enc.shift(f"{self.full_name}/output_scale:0", enc.dummy("post_mult_scale"), self._multiplier_shift)
        # NOTE(review): enforce_encoding derives output_zero_point = -(zp_a * zp_b) / 2**shift, which
        # is non-zero when both input zero points are non-zero. Confirm that pinning it to 0.0 here
        # is intended.
        enc.identity(f"{self.full_name}/output_zero_point:0", np.float32(0.0))
