import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import AdapterType, TrackerStage, TrackerType
from hailo_model_optimization.acceleras.utils.acceleras_exceptions import AccelerasValueError
from hailo_model_optimization.acceleras.utils.modification_tracker.modification_tracker_utils import (
    NPZ,
    AlgoModificationTracker,
    BaseModificationTracker,
)


class BaseModificationApplier:
    """Base class for appliers that replay one kind of recorded modification on layer parameters.

    A concrete subclass validates the tracker it is built with (``_validate_tracker``) and
    implements ``apply_base`` for base-model parameters and ``apply_lora`` for LoRA adapter
    parameters.
    """

    def __init__(self, tracker: BaseModificationTracker):
        # Validate first so an applier is never constructed around a tracker it cannot handle.
        self._validate_tracker(tracker)
        self.tracker = tracker

    @staticmethod
    def _validate_tracker(tracker: BaseModificationTracker):
        """Raise if ``tracker`` is not of the type handled by this applier (subclasses override)."""
        raise NotImplementedError

    def apply(self, adapter_type: AdapterType, modification_params: NPZ, native_params: NPZ, *args, **kwargs) -> NPZ:
        """Route to ``apply_base`` or ``apply_lora`` according to ``adapter_type``.

        Returns:
            Whatever the selected handler returns (the updated parameter mapping).

        Raises:
            AccelerasValueError: if ``adapter_type`` is neither ``AdapterType.BASE`` nor
                ``AdapterType.LORA``.
        """
        if adapter_type == AdapterType.BASE:
            handler = self.apply_base
        elif adapter_type == AdapterType.LORA:
            handler = self.apply_lora
        else:
            raise AccelerasValueError(f"Adapter type {adapter_type} is not supported")
        return handler(modification_params, native_params, *args, **kwargs)

    def apply_base(self, modification_params: NPZ, native_params: NPZ, *args, **kwargs) -> NPZ:
        """Apply the modification to base-model parameters (subclasses override)."""
        raise NotImplementedError

    def apply_lora(self, modification_params: NPZ, native_params: NPZ, *args, **kwargs) -> NPZ:
        """Apply the modification to LoRA adapter parameters (subclasses override)."""
        raise NotImplementedError

    def should_apply(self, stage: TrackerStage) -> bool:
        """Return whether the tracked modification must be replayed when running ``stage``.

        A tracker whose stage is FP_OPTIMIZE is replayed for both FP_OPTIMIZE and QUANTIZE.
        A tracker whose stage is QUANTIZE is replayed only for QUANTIZE. Any other ``stage``
        yields ``False``.
        """
        if stage == TrackerStage.FP_OPTIMIZE:
            return self.tracker.stage == TrackerStage.FP_OPTIMIZE
        if stage != TrackerStage.QUANTIZE:
            return False
        return self.tracker.stage in {TrackerStage.FP_OPTIMIZE, TrackerStage.QUANTIZE}


class FoldNormalizationApplier(BaseModificationApplier):
    """Fold a per-channel scale/shift pair into a layer's kernel and bias.

    The scale is read from ``modification_params[tracker.kernel_key]`` and the shift from
    ``modification_params[tracker.bias_key]``. Kernel and bias entries are updated with
    augmented assignment (``*=`` / ``+=``). For NumPy arrays this mutates the existing array
    objects in place.
    """

    @staticmethod
    def _validate_tracker(tracker: BaseModificationTracker):
        if tracker.tracker_type != TrackerType.FOLD_NORMALIZATION:
            raise AccelerasValueError(f"Tracker {tracker} is not of type {TrackerType.FOLD_NORMALIZATION}")

    def _fold(self, layer_prefix, on_input, modification_params, native_params):
        """Fold the tracked scale/shift into ``<layer_prefix>/kernel:0`` and ``<layer_prefix>/bias:0``.

        Input side: ``bias += shift @ kernel[0, 0]``, computed from the kernel *before* it is
        scaled. Then ``kernel *= scale``.

        Output side: the kernel is multiplied by ``scale`` with its last two axes swapped, the
        bias is multiplied by ``scale[0, 0, :, 0]``, and the shift is then added to the bias.
        """
        kernel_name = f"{layer_prefix}/kernel:0"
        bias_name = f"{layer_prefix}/bias:0"
        if on_input:
            # The shift's contribution uses the still-unscaled kernel, so it must come first.
            # NOTE(review): only the [0, 0] spatial tap is used -- presumably kernels are 1x1 here; confirm.
            shift_contribution = np.matmul(
                modification_params[self.tracker.bias_key],
                native_params[kernel_name][0, 0],
            )
            native_params[bias_name] += shift_contribution
            native_params[kernel_name] *= modification_params[self.tracker.kernel_key]
        else:
            scale = modification_params[self.tracker.kernel_key]
            native_params[kernel_name] *= scale.transpose(0, 1, 3, 2)
            native_params[bias_name] *= scale[0, 0, :, 0]
            native_params[bias_name] += modification_params[self.tracker.bias_key]
        return native_params

    def apply_base(self, modification_params, native_params, *args, **kwargs):
        """Fold the normalization into the layer's own kernel/bias."""
        return self._fold(self.tracker.lname, self.tracker.apply_on_input, modification_params, native_params)

    def apply_lora(self, modification_params, native_params, *args, **kwargs):
        """Fold the normalization into the LoRA weights.

        The input side targets ``<lname>_lora_down``; the output side targets ``<lname>_lora_up``.
        """
        # NOTE(review): on the output side the shift is added to the lora_up bias here AND (in
        # apply_base) to the base bias. If the adapter output is summed with the base output,
        # the shift is counted twice -- confirm this is intended.
        on_input = self.tracker.apply_on_input
        suffix = "_lora_down" if on_input else "_lora_up"
        return self._fold(f"{self.tracker.lname}{suffix}", on_input, modification_params, native_params)


class MatrixMultiplicationApplier(BaseModificationApplier):
    """Multiply a layer's kernel (and, on the output side, its bias) by a tracked matrix.

    The matrix is ``modification_params[tracker.kernel_key]``, transposed first when
    ``tracker.transpose`` is set. The updated arrays are new objects produced by ``@``; the
    originals are not modified in place.
    """

    @staticmethod
    def _validate_tracker(tracker: BaseModificationTracker):
        if tracker.tracker_type != TrackerType.MATRIX_MULTIPLICATION:
            raise AccelerasValueError(f"Tracker {tracker} is not of type {TrackerType.MATRIX_MULTIPLICATION}")

    def _multiply(self, layer_prefix, on_input, modification_params, native_params):
        """Rewrite ``<layer_prefix>/kernel:0`` (and ``bias:0`` on the output side) with the matrix.

        Input side: ``kernel = matrix @ kernel``.
        Output side: ``kernel = kernel @ matrix`` and ``bias = bias @ matrix``.
        """
        matrix = modification_params[self.tracker.kernel_key]
        if self.tracker.transpose:
            matrix = matrix.T
        kernel_name = f"{layer_prefix}/kernel:0"
        if on_input:
            native_params[kernel_name] = matrix @ native_params[kernel_name]
            return native_params
        bias_name = f"{layer_prefix}/bias:0"
        native_params[kernel_name] = native_params[kernel_name] @ matrix
        native_params[bias_name] = native_params[bias_name] @ matrix
        return native_params

    def apply_base(self, modification_params, native_params, *args, **kwargs):
        """Apply the matrix to the layer's own kernel/bias."""
        return self._multiply(self.tracker.lname, self.tracker.apply_on_input, modification_params, native_params)

    def apply_lora(self, modification_params, native_params, *args, **kwargs):
        """Apply the matrix to ``<lname>_lora_down`` (input side) or ``<lname>_lora_up`` (output side)."""
        on_input = self.tracker.apply_on_input
        suffix = "_lora_down" if on_input else "_lora_up"
        return self._multiply(f"{self.tracker.lname}{suffix}", on_input, modification_params, native_params)


class GatherApplier(BaseModificationApplier):
    """Select/reorder a layer's channels using ``modification_params[tracker.indices_key]``.

    Input side: the indices are applied to kernel axis 2. Output side: they are applied to
    kernel axis 3 and to the bias.
    """

    @staticmethod
    def _validate_tracker(tracker: BaseModificationTracker):
        if tracker.tracker_type != TrackerType.GATHER:
            raise AccelerasValueError(f"Tracker {tracker} is not of type {TrackerType.GATHER}")

    def _gather(self, layer_prefix, on_input, modification_params, native_params):
        """Index ``<layer_prefix>/kernel:0`` (and ``bias:0`` on the output side) with the indices."""
        indices = modification_params[self.tracker.indices_key]
        kernel_name = f"{layer_prefix}/kernel:0"
        kernel = native_params[kernel_name]
        if on_input:
            native_params[kernel_name] = kernel[:, :, indices, :]
        else:
            bias_name = f"{layer_prefix}/bias:0"
            native_params[kernel_name] = kernel[:, :, :, indices]
            native_params[bias_name] = native_params[bias_name][indices]
        return native_params

    def apply_base(self, modification_params, native_params, *args, **kwargs):
        """Gather the layer's own kernel/bias channels."""
        return self._gather(self.tracker.lname, self.tracker.apply_on_input, modification_params, native_params)

    def apply_lora(self, modification_params, native_params, *args, **kwargs):
        """Gather channels of ``<lname>_lora_down`` (input side) or ``<lname>_lora_up`` (output side)."""
        on_input = self.tracker.apply_on_input
        suffix = "_lora_down" if on_input else "_lora_up"
        return self._gather(f"{self.tracker.lname}{suffix}", on_input, modification_params, native_params)


class SplitApplier(BaseModificationApplier):
    """Split a layer's parameters across the mini convolutions listed in ``tracker.mini_convs``.

    Each mini conv receives an equal, contiguous slice of the kernel's axis 2, the bias divided
    by the number of mini convs, and the original layer's padding constant (0 when absent).
    The original layer's kernel/bias/padding entries are removed from ``native_params``.
    Dividing the bias evenly presumes the mini-conv outputs are summed downstream (the LoRA
    path refers to an ``ew_add_after``) -- confirm against the graph builder.
    """

    @staticmethod
    def _validate_tracker(tracker: BaseModificationTracker):
        if tracker.tracker_type != TrackerType.SPLIT:
            raise AccelerasValueError(f"Tracker {tracker} is not of type {TrackerType.SPLIT}")

    def apply_base(self, modification_params, native_params, *args, **kwargs):
        """Replace the layer's entries with per-mini-conv kernel/bias/padding entries.

        Raises:
            AccelerasValueError: if the tracker lists no mini convs, or if the kernel's axis-2
                size is not divisible by the number of mini convs.
        """
        lname = self.tracker.lname
        mini_convs = self.tracker.mini_convs
        no_subgroups = len(mini_convs)
        kernel_key = f"{lname}/kernel:0"
        # Validate before popping anything, so a failure leaves native_params untouched.
        if no_subgroups == 0:
            raise AccelerasValueError(f"Split tracker of layer {lname} has no mini convs to split into")
        in_channels = native_params[kernel_key].shape[2]
        if in_channels % no_subgroups != 0:
            # Floor division here would silently drop the trailing kernel channels.
            raise AccelerasValueError(
                f"Cannot split {in_channels} input channels of layer {lname} evenly into {no_subgroups} mini convs"
            )
        subgroup_size = in_channels // no_subgroups

        kernel = native_params.pop(kernel_key)
        bias = native_params.pop(f"{lname}/bias:0")
        padding_const_value = native_params.pop(f"{lname}/padding_const_value:0", 0)
        for group_i, mini_conv in enumerate(mini_convs):
            start = subgroup_size * group_i
            native_params[f"{mini_conv}/kernel:0"] = kernel[:, :, start : start + subgroup_size, :]
            native_params[f"{mini_conv}/bias:0"] = bias / no_subgroups
            native_params[f"{mini_conv}/padding_const_value:0"] = padding_const_value
        return native_params

    def apply_lora(self, modification_params, native_params, *args, **kwargs):
        """Return the LoRA params unchanged. This is only allowed when ``tracker.ew_add_after`` is set.

        Raises:
            AccelerasValueError: if ``tracker.ew_add_after`` is falsy.
        """
        if not self.tracker.ew_add_after:
            raise AccelerasValueError(
                f"Currently we don't support applying split tracker in LORA mode without ew_add_after for layer {self.tracker.lname}"
            )
        return native_params


# Registry of supported tracker types: each TrackerType maps to the applier class that replays
# that kind of modification. AlgoModificationApplier picks the applier from here and rejects
# tracker types that are missing from this mapping.
tracker_types_map = dict(
    [
        (TrackerType.FOLD_NORMALIZATION, FoldNormalizationApplier),
        (TrackerType.MATRIX_MULTIPLICATION, MatrixMultiplicationApplier),
        (TrackerType.GATHER, GatherApplier),
        (TrackerType.SPLIT, SplitApplier),
    ]
)


class AlgoModificationApplier:
    """Replay the modifications recorded in an ``AlgoModificationTracker`` on a layer's parameters."""

    def __init__(self, modification_tracker: AlgoModificationTracker):
        self.modification_tracker = modification_tracker

    @staticmethod
    def _build_applier(tracker: BaseModificationTracker) -> BaseModificationApplier:
        """Instantiate the applier registered for ``tracker.tracker_type``.

        Raises:
            AccelerasValueError: if the tracker type has no registered applier.
        """
        applier_cls = tracker_types_map.get(tracker.tracker_type)
        if applier_cls is None:
            raise AccelerasValueError(f"Tracker type {tracker.tracker_type} is not supported")
        return applier_cls(tracker)

    def apply(
        self, lname: str, adapter_type: AdapterType, stage: TrackerStage, native_params: NPZ, *args, **kwargs
    ) -> NPZ:
        """Apply the trackers recorded for ``lname`` that are relevant for ``stage``, in list order.

        Args:
            lname: layer name used to look up its trackers.
            adapter_type: forwarded to each applier (base vs. LoRA parameters).
            stage: current optimization stage; filters trackers via ``should_apply``.
            native_params: parameter mapping to update.

        Returns:
            The updated parameter mapping. It is returned as-is when no tracker applies.

        Raises:
            AccelerasValueError: if any tracker of the layer has an unsupported type. This is
                checked before any parameter is modified.
        """
        trackers = self.modification_tracker.layers.get(lname, [])
        # Build (and thereby validate) every applier first. Otherwise an unsupported tracker
        # found mid-list could leave native_params partially modified, since some appliers
        # update arrays in place.
        appliers = [self._build_applier(tracker) for tracker in trackers]
        active = [applier for applier in appliers if applier.should_apply(stage)]
        if not active:
            # Nothing to replay for this layer/stage: avoid fetching the modification params.
            return native_params
        modification_params = self.modification_tracker.get_modification_params()
        for applier in active:
            native_params = applier.apply(adapter_type, modification_params, native_params, *args, **kwargs)
        return native_params
