import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import InputSpec

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    EW_MULT_MULTIPLIER_SHIFT,
    HAILO15_EW_MULT_MULTIPLIER_SHIFT,
    OptimizationTarget,
)


class ElementwiseMultOp(BaseAtomicOp):
    """
    Emulates the element-wise multiplication (ew mult) operation, which uses the APU's multiplier.

    The operation order:
        1. receives 2 inputs from the accumulator (in int9 scale)
        2. multiplies them (int17)
        3. shifts the product right by ``_multiplier_shift`` bits so the data fits in the APU multiplier used
           for the activation. The shift is ``EW_MULT_MULTIPLIER_SHIFT``, or ``HAILO15_EW_MULT_MULTIPLIER_SHIFT``
           for the MERCURY / PLUTO optimization targets (see ``_set_multiplier_shift``).

    Before multiplying, each input is tiled along its (height, width, channels) axes according to
    ``input_repeats``. Height and width may broadcast (one side equal to 1); channel counts must match.
    """

    num_inputs = 2
    num_outputs = 1

    def __init__(
        self,
        name: str,
        input_repeats=None,
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        """
        Args:
            name: layer name, forwarded to ``BaseAtomicOp``.
            input_repeats: per-input repeat factors ``[[h, w, c], [h, w, c]]`` applied to the inputs before the
                multiplication. Defaults to all ones (no repetition).
            logger: forwarded to ``BaseAtomicOp``.
            fully_native: forwarded to ``BaseAtomicOp``.
        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        # No shift until create_hw_params() or import_independent_params() sets it.
        self._multiplier_shift = 0
        self.input_spec = [InputSpec(ndim=4), InputSpec(ndim=4)]
        self.input_repeats = input_repeats if input_repeats else [[1, 1, 1], [1, 1, 1]]

    def _shift_factor(self, dtype):
        """Return ``2 ** self._multiplier_shift`` (computed in FLOAT_TYPE_TF) cast to ``dtype``."""
        factor = tf.pow(
            tf.constant(2.0, dtype=self.FLOAT_TYPE_TF), tf.cast(self._multiplier_shift, dtype=self.FLOAT_TYPE_TF)
        )
        return tf.cast(factor, dtype)

    def enforce_encoding(self):
        """
        Infer the output scale and output zero point from the input scales and input zero points.

        The output scale is the element-wise product of the two input scales, multiplied by
        ``2 ** _multiplier_shift`` to compensate for the right shift applied to the product.

        Raises:
            ValueError: if the zero point of input 1 is not 0.
        """
        scale_matrix = self.input_scale_matrix
        post_mult_scale = scale_matrix[0] * scale_matrix[1]
        self.output_scale = post_mult_scale * self._shift_factor(post_mult_scale.dtype)
        if self.input_zero_points[1] != 0:
            raise ValueError(f"Elementwise Mult input 1 zp has to be zeros, received {self.input_zero_points[1]}")
        mean_scale_b = tf.reduce_mean(self.input_scales[1])
        self.output_zero_point = self.input_zero_points[0] / (
            mean_scale_b * self._shift_factor(mean_scale_b.dtype)
        )

    @property
    def input_scale_matrix(self):
        """
        Return both input scales stacked into one ``FLOAT_TYPE_TF`` tensor: shape ``[2]`` when both scales are
        scalars, otherwise ``[2, N]``. A scalar scale is repeated to the length of the other input's scale vector.
        """
        shape_out = [2, -1]
        # NOTE(review): the channel repeat (input_repeats[i][-1]) is only applied in the vector/vector branch
        # below, not in the two mixed scalar/vector branches. Confirm this is intended when channel repeats != 1.
        if self.input_scale_is_scalar(0) and not self.input_scale_is_scalar(1):
            input_scale0 = tf.repeat(self.input_scales[0], len(self.input_scales[1]))
            input_scale_mat = tf.stack([input_scale0, self.input_scales[1]], axis=0)
        elif self.input_scale_is_scalar(1) and not self.input_scale_is_scalar(0):
            input_scale1 = tf.repeat(self.input_scales[1], len(self.input_scales[0]))
            input_scale_mat = tf.stack([self.input_scales[0], input_scale1], axis=0)
        elif self.input_scale_is_scalar(0) and self.input_scale_is_scalar(1):
            input_scale_mat = tf.stack([self.input_scales[0], self.input_scales[1]], axis=0)
            shape_out = [2]
        else:
            input_scales = [
                tf.repeat(input_scale, repeat[-1], axis=-1)
                for input_scale, repeat in zip(self.input_scales, self.input_repeats)
            ]
            input_scale_mat = tf.stack(input_scales, axis=0)
        return tf.reshape(tf.cast(input_scale_mat, self.FLOAT_TYPE_TF), shape_out)

    def create_hw_params(self, optimization_target, **kwargs):
        """Select the multiplier shift for ``optimization_target`` and re-derive the output encoding."""
        self._set_multiplier_shift(optimization_target=optimization_target)
        self.enforce_encoding()

    def _repeated_inputs(self, inputs):
        """
        Return a new list in which input ``i`` is tiled by ``input_repeats[i]`` along axes 1, 2, 3.

        ``inputs`` itself is left untouched; it may be any indexable sequence, including a tuple.
        """
        repeated = list(inputs)
        for i, repeats in enumerate(self.input_repeats):
            tensor = repeated[i]
            for axis, count in enumerate(repeats, start=1):
                tensor = tf.repeat(tensor, count, axis=axis)
            repeated[i] = tensor
        return repeated

    def repeat_inputs(self, inputs):
        """
        Tile each input by ``input_repeats`` in place (``inputs`` must be a mutable list).

        Kept for backward compatibility; the call paths of this op use the non-mutating ``_repeated_inputs``.
        """
        inputs[:] = self._repeated_inputs(inputs)

    def call_hw_sim(self, inputs, **kwargs):
        """Emulate the hardware path: tile the inputs, multiply them, then divide by ``2 ** _multiplier_shift``."""
        # Work on a copy so the caller's input list is not modified (and tuples are accepted).
        repeated = self._repeated_inputs(inputs)
        return tf.multiply(*repeated) / self._shift_factor(repeated[0].dtype)

    def call_native(self, inputs, **kwargs):
        """Native path: tile the inputs and multiply them, without the hardware shift."""
        return tf.multiply(*self._repeated_inputs(inputs))

    @staticmethod
    def _repeated_shape(shape, repeats):
        """Return ``[height, width, channels]`` of ``shape`` scaled by ``repeats``; unknown (None) dims stay None."""
        return [None if dim is None else dim * ratio for dim, ratio in zip(shape[1:], repeats)]

    @staticmethod
    def _dims_broadcastable(dim0, dim1):
        """Return True if two dims are equal, either is 1, or either is unknown (None)."""
        return dim0 is None or dim1 is None or dim0 == 1 or dim1 == 1 or dim0 == dim1

    @staticmethod
    def _broadcast_dim(dim0, dim1):
        """Return the broadcast size of two compatible dims (a dim of 1 takes the other; None means unknown)."""
        if dim0 == 1:
            return dim1
        if dim1 == 1:
            return dim0
        return dim0 if dim0 is not None else dim1

    def _build(self, input_shape):
        """
        Validate that the two (repeated) input shapes can be multiplied element-wise.

        Raises:
            ValueError: if height or width are neither equal nor 1, or the channel counts differ.
                Unknown (None) dims are not checked.
        """
        shape0 = self._repeated_shape(input_shape[0], self.input_repeats[0])
        shape1 = self._repeated_shape(input_shape[1], self.input_repeats[1])

        if not self._dims_broadcastable(shape0[0], shape1[0]):
            raise ValueError(f"EWMult inputs height must either be equal or 1 (for broadcast) {self.full_name}")
        if not self._dims_broadcastable(shape0[1], shape1[1]):
            raise ValueError(f"EWMult inputs width must either be equal or 1 (for broadcast) {self.full_name}")
        if shape0[2] is not None and shape1[2] is not None and shape0[2] != shape1[2]:
            raise ValueError(f"EWMult inputs must have same feature count {self.full_name}")

    def _compute_output_shape(self, input_shape):
        """
        Return ``[batch, height, width, channels]`` of the output.

        Both input shapes are scaled by their ``input_repeats``. Height and width are then broadcast between the
        inputs (a dim of 1 takes the other input's size), matching ``tf.multiply`` broadcasting in the call paths.
        """
        batch = input_shape[0][0]
        shape0 = self._repeated_shape(input_shape[0], self.input_repeats[0])
        shape1 = self._repeated_shape(input_shape[1], self.input_repeats[1])
        return [batch, *(self._broadcast_dim(dim0, dim1) for dim0, dim1 in zip(shape0, shape1))]

    def _set_multiplier_shift(self, optimization_target):
        """Use HAILO15_EW_MULT_MULTIPLIER_SHIFT for MERCURY/PLUTO targets, EW_MULT_MULTIPLIER_SHIFT otherwise."""
        if optimization_target in [OptimizationTarget.MERCURY, OptimizationTarget.PLUTO]:
            self._multiplier_shift = tf.convert_to_tensor(HAILO15_EW_MULT_MULTIPLIER_SHIFT)
        else:
            self._multiplier_shift = tf.convert_to_tensor(EW_MULT_MULTIPLIER_SHIFT)

    def export_independent_params(self):
        """Export the multiplier shift as a float32 array under key ``mult_shift``."""
        return {
            "mult_shift": np.array(self._multiplier_shift, np.float32),
        }

    def export_hw_params(self):
        """Export the multiplier shift as a uint8 array under key ``ew_mult_apu_shift``."""
        return {"ew_mult_apu_shift": np.array(self._multiplier_shift, np.uint8)}

    def import_independent_params(self, params):
        """Restore the multiplier shift from the ``mult_shift`` entry written by ``export_independent_params``."""
        self._multiplier_shift = params["mult_shift"]

    def create_weight_quant_element(self, *args):
        """No-op: this op has no weights to quantize."""
        pass

    def define_constraints(self, enc):
        """Register the encoding constraints of this op on ``enc``."""
        super().define_constraints(enc)

        # compute output_scale
        enc.mul(
            enc.dummy("post_mult_scale"),
            f"{self.full_name}/input_scale:0",
            f"{self.full_name}/input_scale:1",
            inverse=True,
        )
        enc.shift(f"{self.full_name}/output_scale:0", enc.dummy("post_mult_scale"), self._multiplier_shift)

        # force 0 zero_points
        enc.identity(f"{self.full_name}/input_zero_point:0", np.float32(0.0))
        enc.identity(f"{self.full_name}/input_zero_point:1", np.float32(0.0))
        enc.identity(f"{self.full_name}/output_zero_point:0", np.float32(0.0))

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True
