"""
This module implements the QuaRot algorithm for an acceleras model
https://arxiv.org/abs/2404.00456 (QuaRot: Outlier-Free 4-Bit Inference in Rotated LLMs)
https://arxiv.org/abs/2405.16406 (SpinQuant: LLM quantization with learned rotations)
"""

import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_io import HailoInputLayer
from hailo_model_optimization.acceleras.model.preprocess.preprocess import add_preprocess
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationType,
    FormatConversionType,
    OrthoGenType,
    QuantizationAlgorithms,
    ThreeWayPolicy,
)
from hailo_model_optimization.acceleras.utils.modification_tracker.modification_tracker_utils import (
    MatrixMultiplicationTracker,
)
from hailo_model_optimization.acceleras.utils.rotation_utils import get_rotation_matrix
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class QuaRot(OptimizationAlgorithm):
    """
    This module implements the QuaRot algorithm for an acceleras model.

    For every supported equivalence set of the model, a single rotation matrix
    is generated and folded into the weights: producers get their kernel and
    bias right-multiplied by the rotation, consumers get their kernel
    left-multiplied by its transpose. Input layers are handled through the
    input conversion (preprocess) instead of a kernel update.
    """

    def __init__(self, model, model_config, logger_level, logger=None):
        """
        Read the QuaRot configuration and initialise the bookkeeping flags.

        Args:
            model: the acceleras model to be modified.
            model_config: model configuration; its ``quarot`` attribute holds the algorithm config.
            logger_level: forwarded to the base class.
            logger: optional logger, forwarded to the base class.
        """
        super().__init__(model, model_config, name="QuaRot", logger_level=logger_level, logger=logger)
        cfg = self.get_algo_config()
        # Equivalence sets touching the model inputs/outputs are only rotated when explicitly enabled.
        self._should_equalize_inputs = cfg.equalize_inputs == ThreeWayPolicy.enabled
        self._should_equalize_outputs = cfg.equalize_outputs == ThreeWayPolicy.enabled
        # Set while running when an input layer / an output of the model is affected by a rotation.
        self._preprocess_changed = False
        self._postprocess_changed = False

    def _setup(self):
        """Collect the equivalence sets that are eligible for rotation."""
        self._equiv_sets = [
            equiv_set
            for equiv_set in self._model.iter_equiv_sets(QuantizationAlgorithms.quarot)
            if not self._should_skip_equiv_set(equiv_set)
        ]

    def should_skip_algo(self):
        """Return True unless the QuaRot policy is enabled; warn when it is enabled."""
        skip = self.get_algo_config().policy != ThreeWayPolicy.enabled
        if not skip:
            self._logger.warning("QuaRot is enabled. Note that this is an experimental feature.")
        return skip

    def get_algo_config(self):
        """Return the QuaRot section of the model configuration."""
        return self._model_config.quarot

    def log_config(self):
        """No configuration logging is done for this algorithm."""
        pass

    def _run_int(self):
        """Rotate every collected equivalence set, then refresh the preprocess if an input was touched."""
        for equiv_set in self._equiv_sets:
            self._apply_rotation(equiv_set)

        if self._preprocess_changed:
            add_preprocess(self._model)

        if self._preprocess_changed or self._postprocess_changed:
            self._logger.warning("QuaRot has changed the model's preprocess/postprocess configuration.")

    def finalize_global_cfg(self, algo_config):
        """No global-configuration finalisation is needed for this algorithm."""
        pass

    @staticmethod
    def _grouped_layer_counts(equiv_set):
        """Return the set of distinct ``groups`` values (> 1) among the set's matmul and consumer layers."""
        return {
            equiv_layer.layer.groups
            for equiv_layer in equiv_set.matmul + equiv_set.consumers
            if equiv_layer.layer.groups > 1
        }

    def _apply_rotation(self, equiv_set):
        """
        Generate one rotation matrix for ``equiv_set`` and fold it into its producers and consumers.

        Args:
            equiv_set: equivalence set exposing ``producers``, ``consumers``, ``matmul`` and ``outputs``.
        """
        # _should_skip_equiv_set rejects sets with more than one distinct group count,
        # so at most one value is expected here.
        group_counts = self._grouped_layer_counts(equiv_set)
        groups = group_counts.pop() if group_counts else 1

        size = equiv_set.producers[0].layer.output_shape[-1]
        # A Hadamard rotation is used when the per-group size is a power of two;
        # otherwise fall back to a partially random orthogonal matrix.
        log2_group_size = np.log2(size // groups)
        if log2_group_size == int(log2_group_size):
            ortho_gen_type = OrthoGenType.HADAMARD
        else:
            ortho_gen_type = OrthoGenType.PARTIAL_RANDOM
        rot = get_rotation_matrix(size, ortho_gen_type=ortho_gen_type, groups=groups)
        modification_rotation_key = self._modifications_meta_data.add_modification_param("rotation", rot)

        for equiv_layer in equiv_set.producers:
            layer = equiv_layer.layer
            if isinstance(layer, HailoInputLayer):
                # Input layers have no kernel; the rotation goes into the input conversion.
                self._add_preprocess(layer, rot, modification_rotation_key)
                self._preprocess_changed = True
            else:
                self._apply_rotation_source(layer, rot, modification_rotation_key)

        for equiv_layer in equiv_set.consumers:
            self._apply_rotation_consumer(equiv_layer.layer, rot, modification_rotation_key)

        # Any output inside the set means the model's outputs are now rotated.
        if equiv_set.outputs:
            self._postprocess_changed = True

    def _apply_rotation_source(self, layer: BaseHailoConv, rot: np.ndarray, modification_rotation_key: str):
        """
        Right-multiply the native kernel and bias of a producer layer by ``rot``.

        Args:
            layer: producer layer whose kernel and bias are rewritten in place.
            rot: rotation matrix.
            modification_rotation_key: key under which ``rot`` was registered in the modifications metadata.
        """
        self._modifications_meta_data.append(
            layer.full_name,
            MatrixMultiplicationTracker(
                kernel_key=modification_rotation_key,
                transpose=False,
                apply_on_input=False,
            ),
        )
        layer.import_native_kernel(layer.export_native_kernel() @ rot)
        layer.import_native_bias(layer.export_native_bias() @ rot)

    def _apply_rotation_consumer(self, layer: BaseHailoConv, rot: np.ndarray, modification_rotation_key: str):
        """
        Left-multiply the native kernel of a consumer layer by the transpose of the rotation.

        Args:
            layer: consumer layer whose kernel is rewritten in place.
            rot: rotation matrix.
            modification_rotation_key: key under which ``rot`` was registered in the modifications metadata.
        """
        self._modifications_meta_data.append(
            layer.full_name,
            MatrixMultiplicationTracker(
                kernel_key=modification_rotation_key,
                transpose=True,
                apply_on_input=True,
            ),
        )
        if layer.groups > 1:
            # Grouped layers only use the top-left (size / groups) square block of the rotation.
            # NOTE(review): presumably ``rot`` is block-diagonal with identical blocks when
            # generated with groups > 1 - confirm against get_rotation_matrix.
            block = rot.shape[0] // layer.groups
            consumer_rot = rot[:block, :block]
        else:
            consumer_rot = rot
        layer.import_native_kernel(consumer_rot.T @ layer.export_native_kernel())

    def _add_preprocess(self, layer: HailoInputLayer, rot: np.ndarray, modification_rotation_key: str):
        """
        Fold the rotation into the conversion of an input layer.

        If the layer already performs an embedding conversion, the embedding weights are
        right-multiplied by ``rot``. Otherwise the layer is switched to a rotation conversion
        (on both the layer object and its HN element) and ``rot`` is stored as the conversion weights.

        Args:
            layer: input layer to modify.
            rot: rotation matrix.
            modification_rotation_key: currently unused; kept for a signature parallel to the other helpers.
        """
        if layer.conversion_type == FormatConversionType.embedding:
            layer.conversion_weights.embed = layer.conversion_weights.embed @ rot
            return
        layer.conversion_type = FormatConversionType.rotation
        layer.emulate_conversion = True
        layer._hn_element["conversion_type"] = FormatConversionType.rotation.name
        layer._hn_element["emulate_conversion"] = True
        layer.conversion_weights.rotation = rot

    def _should_skip_equiv_set(self, equiv_set) -> bool:
        """
        Decide whether ``equiv_set`` must be left untouched.

        A set is skipped when any of the following holds:
        - it contains unsupported layers;
        - a layer with outgoing edges inside the set has a non-linear activation
          (only checked when the set has no unsupported and no skip layers);
        - one of its source layers is a model input and input equalization is disabled;
        - it contains model outputs and output equalization is disabled;
        - its matmul/consumer layers use more than one distinct group count.

        Returns:
            True if the equivalence set should be skipped, False otherwise.
        """
        unsupported_activations = False
        if len(equiv_set.unsupported) == 0 and len(equiv_set.skip) == 0:
            for lname, out_degree in equiv_set.equiv_set_flow.out_degree:
                if out_degree <= 0:
                    continue
                layer = self._model.layers[lname]
                if layer.has_activation and layer.get_activation_name() not in [ActivationType.LINEAR]:
                    unsupported_activations = True
                    break

        has_input_source = any(
            isinstance(self._model.layers[layer_name], HailoInputLayer) for layer_name in equiv_set.source_layers
        )
        skip_inputs = has_input_source and not self._should_equalize_inputs

        skip_outputs = bool(equiv_set.outputs) and not self._should_equalize_outputs

        # A single rotation cannot serve layers with different group counts.
        skip_groups = len(self._grouped_layer_counts(equiv_set)) > 1

        # Normalise to a real bool: the bare ``or`` chain could return the ``unsupported`` container itself.
        return bool(
            unsupported_activations or equiv_set.unsupported or skip_inputs or skip_outputs or skip_groups
        )
