import numpy as np
import tensorflow as tf

from hailo_model_optimization.acceleras.atomic_ops.base_atomic_op import BaseAtomicOp
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    SHIFT_CALCULATE_BUFFER,
    IgnoreHwLimitationAssertionPolicy,
)
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasUnsupportedError
from hailo_model_optimization.acceleras.utils.opt_utils import calculate_shifts


class MatmulOp(BaseAtomicOp):
    """
    This class emulates the matmul operation, which uses the APU's multiplier
    The operation order:
        1. receives 2 inputs from the accumulator (in int9 scale)
        2. multiplies them (int17)
        3. shifts right 1 bit (int16) - to make sure the data fits in the apu multiplier for the activation
    """

    num_inputs = 2
    num_outputs = 1

    BIT_EXACT_USE_HSIM = False  # TODO: currently there is a bug in this mode, so it is disabled

    def __init__(
        self,
        name: str,
        transpose_matmul_input: bool = True,
        groups=1,
        zp_comp_added: bool = False,
        input_windows: list = None,
        input_tiles: list = None,
        logger=None,
        fully_native=None,
        **kwargs,
    ):
        """
        Configure the matmul op and reset its quantization state to defaults.

        Args:
            name: op name, forwarded to the base class.
            transpose_matmul_input: when True the second input is transposed (last two axes of the
                prepared matrix are swapped) before the multiplication.
            groups: number of groups the feature axis is split into.
            zp_comp_added: when True the trailing ``zp_comp_rank`` channels of every group of the second
                input are treated as zero-point compensation values.
            input_windows: ``[height, width, channels]`` window counts; defaults to ``[1, 1, 1]``.
            input_tiles: tile factors ``[[h, w, c], [h, w, c]]``, one triplet per input; defaults to no tiling.
            logger: forwarded to the base class.
            fully_native: forwarded to the base class and also stored on the instance.
            **kwargs: forwarded to the base class.

        Raises:
            ValueError: if the channel entry of ``input_windows`` is not 1.

        """
        super().__init__(name, logger=logger, fully_native=fully_native, **kwargs)
        self.fully_native = fully_native

        # Quantization / compensation state, overwritten later by create_hw_params or an import.
        self._multiplier_shift = 0
        self._transpose_matmul_input = transpose_matmul_input
        self._groups = groups
        self._zp_comp_added = zp_comp_added
        self.shift_delta = 0
        self.feed_repeat = 1
        self.online_zp_compensation = True
        self.zp_comp_rank = 0

        if input_windows is None:
            input_windows = [1, 1, 1]
        self._input_windows = input_windows
        if self._input_windows[2] != 1:
            msg = "windowing over the channel dimension is not supported"
            raise ValueError(msg)

        # Input tiles mean that each group of the input is tiled.
        # For example, with 6 groups, a channel tile of 3 and the input [1, 2, 3, 4], the resulting
        # input is [1, 2, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4].
        if input_tiles is None:
            input_tiles = [[1, 1, 1], [1, 1, 1]]
        self._input_tiles = input_tiles

    @property
    def input_tiles(self):
        return self._input_tiles

    def enforce_encoding(self):
        """
        Infer the output scale and zero point from the input scales.

        Three cases are handled:
            * both input scales are scalars: the output scale is their product;
            * vector scales with zero-point compensation channels: one scale per group is computed from the
              first scale of each group and repeated over the existing ``output_scale`` length;
            * vector scales otherwise: the product of the first element of each input scale is broadcast
              over the output features.

        In every case the result is multiplied by ``2 ** multiplier_shift``. The output zero point and the
        zero point of the second input are reset to 0.
        """
        data_tile = self._input_tiles[0][2]
        weight_tile = self._input_tiles[1][2]
        shift_factor = 2**self._multiplier_shift

        scalar_scales = (
            len(tf.convert_to_tensor(self.input_scales[0]).shape) == 0
            and len(tf.convert_to_tensor(self.input_scales[1]).shape) == 0
        )
        if scalar_scales:
            post_mult_scale = self.input_scales[0] * self.input_scales[1]
            self.output_scale = post_mult_scale * shift_factor
        elif self._zp_comp_added:
            # Take the first scale of each group; the tiled input contributes one scale per tile.
            input_a = np.array(self.input_scales[0]).reshape(
                ((self.groups // data_tile) // weight_tile, weight_tile, -1)
            )[..., 0]
            input_b = np.array(self.input_scales[1]).reshape(
                ((self.groups // weight_tile) // data_tile, data_tile, -1)
            )[..., 0]
            out_values = input_a * input_b * shift_factor

            # np.repeat needs an integer repeat count; true division would hand it a float and raise.
            features_per_group = self.output_scale.shape[0] // self.groups
            self.output_scale = np.repeat(out_values.reshape(self.groups), features_per_group)
        else:
            post_mult_scale = self.input_scales[0][0] * self.input_scales[1][0]
            self.output_scale = np.ones(self.output_shape[-1]) * post_mult_scale * shift_factor

        self.output_zero_point = np.float32(0)
        self.input_zero_points[1] = np.float32(0)

    def _internal_matmul(self, inputs, native=False, online_zp_compensation=False):
        """
        Run the full matmul flow: prepare both inputs, multiply, and restore the output layout.

        When ``self._zp_comp_added`` is set, the trailing channels of each group of the weights input
        (``inputs[1]``) hold zero-point compensation values (the negated sum of the other channels), so the
        result is ``matmul(data, weights[:, :-1]) + ZP * weights[:, -1]``.

        Args:
            inputs: pair ``[data, weights]``.
            native: when True the compensation term is zeroed and no shift is applied.
            online_zp_compensation: forwarded to the weights preparation step.

        Returns:
            The matmul result reshaped back to the layout of ``inputs[0]``.

        """
        windows = self._input_windows
        data_mat = self._prepare_matmul_data_input(inputs[0], self.groups, windows)
        weights_mat, zp_comp = self._prepare_matmul_weights_input(
            inputs[1],
            self.groups,
            self._transpose_matmul_input,
            self._zp_comp_added,
            windows,
            online_zp_compensation=online_zp_compensation,
        )
        operands = [data_mat, weights_mat]

        shift = self._multiplier_shift
        if native:
            # Native emulation ignores both the compensation channels and the hardware shift.
            zp_comp = tf.zeros_like(zp_comp, dtype=weights_mat.dtype)
            shift = 0

        use_hw_emulation = (not native) and self.bit_exact
        if use_hw_emulation:
            hsim = self._hsim if self.BIT_EXACT_USE_HSIM else None
            matmul_result = self._calculate_matmul_hw(
                operands,
                zp_comp,
                shift,
                hsim=hsim,
                accumulator_size=self.output_lossy_element.bits,
            )
        else:
            matmul_result = self._calculate_matmul(operands, zp_comp) / (2**shift)

        return self._prepare_matmul_output(inputs[0], matmul_result, self.groups, windows)

    def call_native(self, inputs, **kwargs):
        return self._internal_matmul(inputs, native=True)

    def call_hw_sim(self, inputs, **kwargs):
        return self._internal_matmul(inputs, online_zp_compensation=self.online_zp_compensation)

    def _build(self, input_shape):
        input0_shape = input_shape[0]
        input1_shape = input_shape[1]
        data_tile = self._input_tiles[0][2]
        weight_tile = self._input_tiles[1][2]
        window_height, window_width, _ = self._input_windows
        if any(tile != 1 for tile in self._input_tiles[0][:-1] + self._input_tiles[1][:-1]) or (
            data_tile != 1 and weight_tile != 1
        ):
            raise AccelerasUnsupportedError(
                f"Unexpected input tiles at {self.full_name}, input_tiles={self._input_tiles}. Only feature tile for a single input is supported.",
            )
        if self.groups % data_tile != 0 or self.groups % weight_tile != 0:
            raise AccelerasUnsupportedError(
                f"Unexpected groups at {self.full_name}, groups={self.groups}, input_tiles={self._input_tiles}",
            )
        if input0_shape[3] % (self.groups // data_tile) != 0 or input1_shape[3] % (self.groups // weight_tile) != 0:
            raise AccelerasUnsupportedError(
                f"Unexpected input shapes at {self.full_name}, input_shapes={input_shape} (type={type(input_shape)})",
            )
        if (
            input0_shape[1] % window_height != 0
            or input0_shape[2] % window_width != 0
            or input1_shape[1] % window_height != 0
            or input1_shape[2] % window_width != 0
        ):
            raise AccelerasUnsupportedError(
                f"Unexpected input shapes at {self.full_name}, input_shapes={input_shape} (type={type(input_shape)})",
            )
        if self._transpose_matmul_input:
            if input0_shape[3] // (self.groups // data_tile) != (
                input1_shape[3] // (self.groups // weight_tile) - self.zp_comp_rank
            ):
                raise AccelerasUnsupportedError(
                    f"Unexpected input shapes at {self.full_name}, input_shapes={input_shape} (type={type(input_shape)})",
                )
        else:
            if input0_shape[3] // (self.groups // data_tile) != input1_shape[1] * input1_shape[2] // (
                window_height * window_width
            ):
                raise AccelerasUnsupportedError(
                    f"Unexpected input shapes at {self.full_name}, input_shapes={input_shape} (type={type(input_shape)})",
                )

    def create_hw_params(self, preact_limvals, hw_shifts=None, **kwargs):
        """
        Set the multiplier shift according to the calibrated range and refresh the output encoding.

        Args:
            preact_limvals: pair ``(min, max)`` of calibrated limits at the op output.
            hw_shifts: forwarded to ``calculate_shifts``.
            **kwargs: ignored.

        Raises:
            AccelerasUnsupportedError: if the first input has a non-zero zero point while the second input
                is not transposed, or if a shift delta is required and the hw-limitation assertion is not
                configured to be ignored.

        """
        # TODO: verify all the scales are the same
        if np.any(self.input_zero_points[0] != 0) and not self._transpose_matmul_input:
            raise AccelerasUnsupportedError(
                f"{self.full_name}: Matmul without transpose input and with negative range is not supported",
            )
        data_tile = self._input_tiles[0][2]
        weight_tile = self._input_tiles[1][2]

        if self._zp_comp_added:
            # Split each group's scales into the real channels and the trailing compensation channels.
            group_scales = np.reshape(self.input_scales[1], ((self.groups // weight_tile) // data_tile, data_tile, -1))
            signed_scales = group_scales[..., : -1 * self.zp_comp_rank]
            zp_comp_scales = group_scales[..., -1 * self.zp_comp_rank :].reshape(-1, self.zp_comp_rank)
            if self.zp_comp_rank == 2:
                # TODO let the feed repeat to be based on scales.
                self.feed_repeat = (zp_comp_scales.T / zp_comp_scales.min(axis=1)).T[0]
        else:
            signed_scales = self.input_scales[1]

        # One scale per group: the product of the first scale of each input's group.
        matmul_scale = (
            np.reshape(self.input_scales[0], ((self.groups // data_tile) // weight_tile, weight_tile, -1))[..., 0]
            * np.reshape(signed_scales, ((self.groups // weight_tile) // data_tile, data_tile, -1))[..., 0]
        ).reshape(self.groups)
        limval_max = preact_limvals[1]
        limval_min = preact_limvals[0]
        limvals = np.maximum(np.abs(limval_max), np.abs(limval_min))
        expected_max_output = np.max(limvals / matmul_scale)

        accumulator_size = self.output_lossy_element.bits

        needed_shift, shift_delta = calculate_shifts(
            expected_max_output, accumulator_size, SHIFT_CALCULATE_BUFFER, hw_shifts=hw_shifts
        )
        if shift_delta != 0 and self._ignore_hw_limitation_assertion != IgnoreHwLimitationAssertionPolicy.enabled:
            # Build a message that tells the user which input ranges would remove the shift delta.
            name_to_display = "/".join(self.full_name.split("/")[:-1])
            factor = self._trunc_plus(2**shift_delta, 3)
            factor_sqrt = np.sqrt(factor)
            range_min0, range_max0 = self.get_input_limvals(0)
            range_min0, range_max0 = self._trunc_plus(range_min0, 3), self._trunc_plus(range_max0, 3)
            range0_str = f"[{range_min0:.03f}, {range_max0:.03f}]"
            range_min1, range_max1 = self.get_input_limvals(1)
            range_min1, range_max1 = self._trunc_plus(range_min1, 3), self._trunc_plus(range_max1, 3)
            range1_str = f"[{range_min1:.03f}, {range_max1:.03f}]"
            range0_fix_str = f"[{range_min0*factor_sqrt:.03f}, {range_max0*factor_sqrt:.03f}]"
            range1_fix_str = f"[{range_min1*factor_sqrt:.03f}, {range_max1*factor_sqrt:.03f}]"
            raise AccelerasUnsupportedError(
                f"layer {name_to_display} does not support shift delta. To overcome this issue you should "
                f"force larger range at the inputs of the layer using command "
                f"quantization_param([layer_name], force_range_in=[range_min, range_max], force_range_index=index) "
                f"current range of input 0 is {range0_str} and input 1 is {range1_str}. "
                f"You should increase the multiplication of these ranges by a factor of {factor:.03f}, "
                f"e.g. you can apply factor of sqrt({factor:.03f}) to both inputs:\n"
                f"quantization_param([{name_to_display}], force_range_in={range0_fix_str}, force_range_index=0)\n"
                f"quantization_param([{name_to_display}], force_range_in={range1_fix_str}, force_range_index=1)\n",
            )
        self._multiplier_shift = needed_shift
        self.shift_delta = 0
        self.enforce_encoding()

    @staticmethod
    def _trunc_plus(value, decimals=0):
        """
        Truncate a float to a certain number of decimal places, and round up the last digit.
        """
        return (np.trunc(value * 10**decimals) + np.sign(value)) / 10**decimals

    def export_independent_params(self):
        return {
            "mac_shift": np.array(self._multiplier_shift, np.float32),
            "shift_delta": np.array(self.shift_delta, np.float32),
            "zero_point_feed_repeat": np.array(self.feed_repeat, np.float32),
            "zero_point_compensation_added": bool(self._zp_comp_added),
        }

    def import_independent_params(self, params):
        self._multiplier_shift = params["mac_shift"]
        self.shift_delta = params["shift_delta"]
        self.feed_repeat = params["zero_point_feed_repeat"]
        self._zp_comp_added = bool(params["zero_point_compensation_added"])

    def export_hw_params(self):
        zp_feed_repat = [self.feed_repeat] if self.zp_comp_rank != 2 else self.feed_repeat
        zero = np.array(self.input_zero_points[0], dtype=np.uint16)
        if zero.ndim == 0:
            zp_x_vals = np.array([zero, zero])
        else:
            # TODO Add support for groups, and for changes on zeros.
            zp_x_vals = np.array([zero[0], zero[0]])
        # zp_
        return {
            # Mac Must Values
            "kernel": np.array([], np.int8),
            "zp_kernel": np.array([], np.int32),
            "output_stage/mult_shift": np.array(self._multiplier_shift, np.uint8),
            # Matmul Extra
            "zp_feed_repeat": np.array(zp_feed_repat, dtype=np.uint16),
            "zp_x_vals": zp_x_vals,
        }

    @staticmethod
    def _prepare_matmul_input_general(inp, groups, input_windows, input_tile, other_input_tile):
        """
        Rearrange a 4-D input into per-window, per-group matrices.

        Args:
            inp: tensor whose shape unpacks as ``(batch, height, width, features)``.
            groups: total number of groups of the op.
            input_windows: ``[window_height, window_width, _]`` window counts.
            input_tile: feature tile factor of this input.
            other_input_tile: feature tile factor of the other input.

        Returns:
            Tensor of shape ``[-1, window_height, window_width, input_groups // other_input_tile,
            other_input_tile, positions_per_window, features_per_group]``.

        """
        _, height, width, total_features = inp.shape
        window_height, window_width, _ = input_windows

        input_groups = groups // input_tile
        group_features = total_features // input_groups
        outer_groups = input_groups // other_input_tile
        rows_per_window = height // window_height
        cols_per_window = width // window_width

        # Split the spatial axes into (window, position in window) and the feature axis into groups.
        split_shape = [
            -1,
            window_height,
            rows_per_window,
            window_width,
            cols_per_window,
            outer_groups,
            other_input_tile,
            group_features,
        ]
        windowed = tf.reshape(inp, split_shape)

        # Bring the window and group axes to the front, leaving positions and features last.
        windowed = tf.transpose(windowed, [0, 1, 3, 5, 6, 2, 4, 7])

        # Flatten the positions inside a window into the row axis of the matrix.
        matrix_shape = [
            -1,
            window_height,
            window_width,
            outer_groups,
            other_input_tile,
            rows_per_window * cols_per_window,
            group_features,
        ]
        return tf.reshape(windowed, matrix_shape)

    def _prepare_matmul_data_input(self, inp, groups, input_windows):
        return self._prepare_matmul_input_general(
            inp, groups, input_windows, self._input_tiles[0][2], self._input_tiles[1][2]
        )

    def _prepare_matmul_weights_input(
        self, inp, groups, is_transposed, zp_comp_added, input_windows, online_zp_compensation=False
    ):
        if online_zp_compensation and zp_comp_added and not self.bit_exact:
            inp = self._calculate_zp_compensation(inp, groups // self._input_tiles[1][2])

        weights_inp = self._prepare_matmul_input_general(
            inp, groups, input_windows, self._input_tiles[1][2], self._input_tiles[0][2]
        )

        if zp_comp_added:
            split = -1 * self.zp_comp_rank
            zp_comp = weights_inp[..., split:]
            if is_transposed:
                zp_comp = tf.transpose(zp_comp, [0, 1, 2, 3, 4, 6, 5])
            weights_inp = weights_inp[..., :split]
        else:
            zp_comp = 0

        if is_transposed:
            weights_inp = tf.transpose(weights_inp, [0, 1, 2, 3, 4, 6, 5])
        return weights_inp, zp_comp

    def _prepare_matmul_output(self, inp, matmul_result, groups, input_windows):
        """
        Rearrange the per-window, per-group matmul result back into a 4-D tensor.

        Args:
            inp: the first input of the op; only its height and width are used.
            matmul_result: the matmul output; its last axis is the per-group output features.
            groups: total number of groups of the op.
            input_windows: ``[window_height, window_width, _]`` window counts.

        Returns:
            Tensor of shape ``[-1, height, width, output_features * groups]``.

        """
        _, height, width, _ = inp.shape
        window_height, window_width, _ = input_windows
        group_features = matmul_result.shape[-1]
        rows_per_window = height // window_height
        cols_per_window = width // window_width

        # Unflatten the row axis of every matrix into its spatial position inside the window.
        expanded_shape = [
            -1,
            window_height,
            window_width,
            groups,
            rows_per_window,
            cols_per_window,
            group_features,
        ]
        expanded = tf.reshape(matmul_result, expanded_shape)

        # Interleave windows with their positions and move the groups next to the features.
        expanded = tf.transpose(expanded, [0, 1, 4, 2, 5, 3, 6])
        return tf.reshape(expanded, [-1, height, width, group_features * groups])

    def _calculate_zp_compensation(self, inp, groups):
        """
        Replace the "evaluated" zero-point compensation channels with exact values computed online.

        For every group the trailing ``zp_comp_rank`` channels are dropped and replaced by the negated sum
        of the remaining channels, divided by ``feed_repeat``. When ``zp_comp_rank`` is 2 the second
        compensation channel is filled with zeros.

        Args:
            inp: 4-D weights input whose last axis holds ``groups`` equally sized slices.
            groups: number of slices along the last axis.

        Returns:
            Tensor with the same layout as ``inp`` and recomputed compensation channels.

        """
        slice_size = inp.shape[-1] // groups
        concat_input = []
        for ind in range(groups):
            tensor_sel = inp[:, :, :, ind * slice_size : ind * slice_size + slice_size - self.zp_comp_rank]
            tensor_sum = -tf.reduce_sum(tensor_sel, axis=-1, keepdims=True)
            if self.zp_comp_rank == 2:
                # zeros_like keeps the dtype and also works when a dimension (e.g. batch) is unknown,
                # unlike np.zeros on the static shape.
                vals = tf.concat([tensor_sum, tf.zeros_like(tensor_sum)], axis=-1)
            else:
                vals = tensor_sum
            vals = vals / self.feed_repeat
            concat_input.append(tensor_sel)
            concat_input.append(vals)
        return tf.concat(concat_input, axis=-1)

    def _calculate_matmul(self, inputs, zp_comp=0):
        """
        Multiply the prepared data and weights matrices and add the zero-point compensation term.

        Args:
            inputs: pair ``[data, weights]`` of prepared matrices.
            zp_comp: compensation channels split from the weights; used only when ``_zp_comp_added`` is set.

        Returns:
            ``matmul(data, weights)`` plus the compensation residue (0 when no compensation is added).

        """
        residue = 0
        if self._zp_comp_added:
            # Weight every compensation channel by its feed repeat and by the data zero point.
            zp_comp = tf.expand_dims(zp_comp, axis=-3)
            repeat = tf.reshape(self.feed_repeat, [1, 1, 1, 1, 1, 1, -1, 1])
            repeat = tf.cast(repeat, zp_comp.dtype)
            residue = tf.reduce_sum(self.input_zero_points[0] * zp_comp * repeat, axis=-2)

        data, weights = inputs[0], inputs[1]
        return tf.matmul(data, weights) + residue

    def _calculate_matmul_hw(self, inputs, zp_comp=0, shift=0, hsim=None, accumulator_size=None):
        """
        Multiply the prepared matrices while passing every shifted product through ``output_lossy_element``.

        Args:
            inputs: pair ``[data, weights]`` of prepared matrices.
            zp_comp: compensation channels split from the weights; used only when ``_zp_comp_added`` is set.
            shift: number of bits each product is divided by (``2 ** shift``).
            hsim: optional object exposing ``h_matmul``; when given it performs the multiplication.
            accumulator_size: forwarded to ``hsim.h_matmul``.

        Returns:
            The accumulated products plus the compensation residue.

        """
        residue = 0
        if self._zp_comp_added:
            zp_comp = tf.expand_dims(zp_comp, axis=-3)
            repeat = tf.cast(tf.reshape(self.feed_repeat, [1, 1, 1, 1, 1, 1, -1, 1]), tf.float32)
            shifted_comp = self.output_lossy_element(self.input_zero_points[0] * zp_comp / 2**shift)
            residue = tf.reduce_sum(tf.cast(shifted_comp, tf.float32) * repeat, axis=-2)

        if hsim is None:
            # Emulate the accumulation: form every pairwise product, quantize it, then sum over the
            # shared (inner) axis.
            data = tf.expand_dims(inputs[0], axis=-1)
            weights = tf.expand_dims(inputs[1], axis=-3)
            products = self.output_lossy_element(data * weights / 2**shift)
            return tf.reduce_sum(products, axis=-2) + residue

        hsim_op = hsim.h_matmul
        data = tf.cast(inputs[0], self.INT_TYPE_TF)
        weights = tf.cast(inputs[1], self.INT_TYPE_TF)
        product = hsim_op(
            data,
            weights,
            shift,
            accumulator_size=accumulator_size,
            use_fp16_acc=False,
            name="matmul",
        )
        return product + tf.cast(residue, self.INT_TYPE_TF)

    def _compute_output_shape(self, input_shape):
        input_shape0 = input_shape[0]
        input_shape1 = input_shape[1]
        number_of_windows = np.prod(self._input_windows)
        if self._transpose_matmul_input:
            output_features = input_shape1[1] * input_shape1[2] // number_of_windows
        else:
            weight_groups = self.groups // self._input_tiles[1][2]
            if self._zp_comp_added:
                output_features = input_shape1[-1] // weight_groups - self.zp_comp_rank
            else:
                output_features = input_shape1[-1] // weight_groups

        return [*input_shape0[:-1], output_features * self.groups]

    def create_weight_quant_element(self, **kwargs):
        pass

    @property
    def bit_exact_supported(self) -> bool:
        """This layer supports bit exact emulation."""
        return True

    @property
    def groups(self):
        return self._groups
