import inspect
from functools import wraps
from typing import NamedTuple, Optional

import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_add import HailoElementwiseAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult import HailoElementwiseMult
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_mult_on_mac import HailoElementwiseMultOnMac
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_sub import HailoElementwiseSub
from hailo_model_optimization.acceleras.hailo_layers.hailo_matmul import HailoMatmul
from hailo_model_optimization.acceleras.hailo_layers.hailo_normalization import HailoNormalization
from hailo_model_optimization.acceleras.hailo_layers.hailo_reduce_max import HailoReduceMax
from hailo_model_optimization.acceleras.hailo_layers.hailo_reduce_sum import HailoReduceSum
from hailo_model_optimization.acceleras.hailo_layers.hailo_resize_nearest_neighbor import HailoResizeNearestNeighbor
from hailo_model_optimization.acceleras.hailo_layers.hailo_softmax_mars import HailoSoftmaxMars
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_base import CommandMeta
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerTranslationConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationFitPolicy,
    ActivationType,
    OptimizationTarget,
    ThreeWayPolicy,
)
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class SoftmaxBlock(NamedTuple):
    """Names of the layers that make up one softmax pattern in the model graph.

    Every field holds a layer name. Fields typed ``Optional`` are ``None`` when the
    corresponding layer is absent from (or not tracked for) the detected pattern.
    Field order is part of the interface and must not change.
    """

    # Layer whose output is the softmax input.
    matmul: str
    # Scaling layer between the matmul and the softmax, if any.
    mull: Optional[str]
    # Reduce-max layer of the pattern.
    reduce_max: str
    # Resize between reduce_max and ew_sub; None when reduce_max feeds ew_sub directly.
    resize1: Optional[str]
    # Element-wise subtraction layer (the one carrying the exponent activation).
    ew_sub: str
    # The remaining layers are not populated by the block finder in this file,
    # so they default to None.
    reduce_sum: Optional[str] = None
    resize2: Optional[str] = None
    ew_mult: Optional[str] = None


def check_successor_name_is_not_none(func):
    """Make a predicate return ``False`` when its ``successor_name`` argument is ``None``.

    The wrapped callable is invoked unchanged for any other value, including the
    case where ``successor_name`` is not supplied at all.

    Args:
        func: callable that may declare a ``successor_name`` parameter.

    Returns:
        A wrapper with the same signature as ``func``.
    """
    # Resolve the positional slot once, at decoration time, instead of on every call.
    # The index is used as-is against ``args``: for methods ``self`` is part of both
    # the declared argument names and ``args``, so no offset is needed.
    arg_names = inspect.getfullargspec(func).args
    s_index = arg_names.index("successor_name") if "successor_name" in arg_names else None

    @wraps(func)
    def wrapper(*args, **kwargs):
        if s_index is not None and s_index < len(args):
            successor_name = args[s_index]
        else:
            # The "" fallback only needs to be "not None", so that a missing
            # argument is left for ``func`` itself to deal with.
            successor_name = kwargs.get("successor_name", "")

        if successor_name is None:
            return False

        return func(*args, **kwargs)

    return wrapper


class SmartSoftmaxStats(OptimizationAlgorithm):
    """
    This class is responsible for changing the softmax statistics to be based on the reduce_max.
    Instead of using the regular min max limvals which is for every x
        limvals = min(min(x)),  max(max(x))

    We will use:
        limvals = min(max(x)) + np.log(1/2**(bits+1)),  max(max(x))
    """

    def __init__(self, model, model_config, logger_level, logger=None):
        """Initialise the base algorithm under the display name "Smart Softmax Stats"."""
        super().__init__(
            model,
            model_config,
            logger=logger,
            logger_level=logger_level,
            name="Smart Softmax Stats",
        )

    def _setup(self):
        """Delegate to the parent class setup and hand back whatever it returns."""
        return super()._setup()

    def should_skip_algo(self):
        """Return True when the configured policy equals ``ThreeWayPolicy.disabled``."""
        return self.get_algo_config().policy == ThreeWayPolicy.disabled

    def get_algo_config(self):
        return self._model_config.smart_softmax_stats

    def log_config(self):
        pass

    def _get_input_softmax(self, lname):
        """
        Resolve the layer whose output actually feeds the softmax pattern.

        Starting from ``lname``: if its only successor is a HailoElementwiseAdd, move onto it;
        then, if the only successor of the (possibly updated) layer is a HailoNormalization,
        move onto that one as well. The name reached is returned.
        """
        for hop_type in (HailoElementwiseAdd, HailoNormalization):
            following = self._model.flow.successors_sorted(lname)
            if len(following) != 1:
                continue
            if isinstance(self._model.layers[following[0]], hop_type):
                lname = following[0]
        return lname

    @check_successor_name_is_not_none
    def _check_is_ew_sub(self, successor_name):
        """Tell whether the named layer is a HailoElementwiseSub whose activation is EXP."""
        candidate = self._model.layers[successor_name]
        if not isinstance(candidate, HailoElementwiseSub):
            return False
        return candidate.act_op.act_name == ActivationType.EXP

    @check_successor_name_is_not_none
    def _check_is_resize(self, successor_name):
        """Tell whether the named layer is a HailoResizeNearestNeighbor."""
        return isinstance(self._model.layers[successor_name], HailoResizeNearestNeighbor)

    @check_successor_name_is_not_none
    def _check_is_reduce_max(self, successor_name):
        """Tell whether the named layer is a HailoReduceMax."""
        return isinstance(self._model.layers[successor_name], HailoReduceMax)

    @check_successor_name_is_not_none
    def _check_is_ew_mult(self, successor_name):
        """Tell whether the named layer is an element-wise multiplication with a LINEAR activation."""
        candidate = self._model.layers[successor_name]
        if not isinstance(candidate, (HailoElementwiseMult, HailoElementwiseMultOnMac)):
            return False
        return candidate.act_op.act_name == ActivationType.LINEAR

    @check_successor_name_is_not_none
    def _check_is_reduce_sum(self, successor_name):
        """Tell whether the named layer is a HailoReduceSum."""
        return isinstance(self._model.layers[successor_name], HailoReduceSum)

    def skip_precision_change(self, list_lnames):
        """
        if we have precision change layer we want to skip it and go to the next layer
        """
        while any("precision_change" in lname for lname in list_lnames):
            to_replace = []
            for l_ind, lname in enumerate(list_lnames):
                if "precision_change" in lname:
                    successors_layers = self._model.flow.successors_sorted(lname)
                    to_replace.append((l_ind, successors_layers))

            for l_ind, successors_layers in reversed(to_replace):
                list_lnames.pop(l_ind)
                list_lnames = list_lnames[:l_ind] + successors_layers + list_lnames[l_ind:]

        return list_lnames

    def skip_ew_mult_mask(self, list_lnames):
        """
        Step over a lone element-wise multiplication layer.

        When ``list_lnames`` holds exactly one name and that layer is a HailoElementwiseMult or
        HailoElementwiseMultOnMac, the sorted successors of that layer are returned. In every
        other case ``list_lnames`` is returned as received.
        """
        if len(list_lnames) == 1:
            only_name = list_lnames[0]
            if isinstance(self._model.layers[only_name], (HailoElementwiseMult, HailoElementwiseMultOnMac)):
                return self._model.flow.successors_sorted(only_name)
        return list_lnames

    def _get_layer_from_list(self, list_lnames, lname):
        for ln in list_lnames:
            if lname in ln:
                return ln
        return None

    def find_and_build_softmax_blocks(self):
        """
        Scan the model for softmax patterns and store them in ``self.softmax_blocks``.

        For every HailoMatmul layer the softmax input is resolved with ``_get_input_softmax``.
        A block is recorded when, after skipping precision changes and a lone element-wise
        multiplication, that input has exactly two consumers: an exponent ew_sub and a
        reduce_max, with the reduce_max feeding the ew_sub either directly or through a
        single resize layer. Consumers are picked out of the list by name substring
        ("ew_sub", "reduce_max", "resize") and then confirmed by layer type.
        """
        model = self._model
        self.softmax_blocks = []

        for lname in model.layers:
            if not isinstance(model.layers[lname], HailoMatmul):
                continue

            input_matmul = self._get_input_softmax(lname)
            already_recorded = any(input_matmul == block.matmul for block in self.softmax_blocks)
            if already_recorded:
                continue

            # The softmax input must fan out to exactly (ew_sub, reduce_max).
            consumers = model.flow.successors_sorted(input_matmul)
            consumers = self.skip_precision_change(consumers)
            consumers = self.skip_ew_mult_mask(consumers)
            consumers = self.skip_precision_change(consumers)
            if len(consumers) != 2:
                continue

            ew_sub = self._get_layer_from_list(consumers, "ew_sub")
            reduce_max = self._get_layer_from_list(consumers, "reduce_max")
            if not self._check_is_ew_sub(ew_sub):
                continue
            if not self._check_is_reduce_max(reduce_max):
                continue

            after_reduce_max = self.skip_precision_change(model.flow.successors_sorted(reduce_max))
            if len(after_reduce_max) != 1:
                continue

            if after_reduce_max[0] == ew_sub:
                # reduce_max is wired straight into ew_sub, there is no resize in between
                resize1 = None
            else:
                resize1 = self._get_layer_from_list(after_reduce_max, "resize")
                if not self._check_is_resize(resize1):
                    continue

            found = SoftmaxBlock(
                matmul=input_matmul,
                mull=None,  # a separate mull layer is not expected to exist at this point
                reduce_max=reduce_max,
                resize1=resize1,
                ew_sub=ew_sub,
                reduce_sum=None,
                resize2=None,
                ew_mult=None,
            )
            self._logger.debug(f"found softmax block {input_matmul}")
            self.softmax_blocks.append(found)

    def _run_int(self):
        """
        Run the algorithm.

        For the MARS optimization target the dedicated ``_run_int_mars`` flow is used. For any
        other target the softmax blocks are discovered and, for each one, the smart range is
        applied and the exponent activation fitting is disabled.
        """
        if self.optimization_target != OptimizationTarget.MARS:
            self.find_and_build_softmax_blocks()
            for block in self.softmax_blocks:
                self.apply_smart_softmax_range(block)
                self.disable_exponent_activation_fitting(block)
            return

        self._run_int_mars()

    def _get_softmax_input_name(self, lname):
        return self._model.flow.predecessors_sorted(lname)[0]

    def _change_matmul_stats_layer(self, matmul_layer, new_min, groups):
        # the is a vevtor of shap (1, channels = (channnel_per_group*groups))
        out_stats_of_matmul = matmul_layer.get_output_stats()[0]
        # Get mins based on groups
        stats_group_min = out_stats_of_matmul.min.reshape(groups, -1).min(axis=1)

        best_min = np.maximum(new_min, stats_group_min)
        repates = out_stats_of_matmul.min.size / groups
        out_stats_of_matmul.min[...] = best_min.repeat(repates)

    def _change_softmax_stats_layer(self, softmax_layer, new_min, groups):
        softmax_layer.change_softmax_stats_layer(new_min, groups)

    def _run_int_mars(self):
        """
        MARS flow: update the stats around every HailoSoftmaxMars layer.

        For each such layer the new range is requested from the softmax layer itself and then
        applied both to its first predecessor and to the softmax layer. ``self.softmax_blocks``
        is reset to an empty list.
        """
        self.softmax_blocks = []
        for lname in self._model.layers:
            if not isinstance(self._model.layers[lname], HailoSoftmaxMars):
                continue

            matmul_layer = self._model.layers[self._get_softmax_input_name(lname)]
            softmax_layer = self._model.layers[lname]

            new_min, groups = softmax_layer.get_softmax_new_range()
            self._change_matmul_stats_layer(matmul_layer, new_min, groups)
            self._change_softmax_stats_layer(softmax_layer, new_min, groups)

    def _get_groups(self, softmax_block: SoftmaxBlock):
        """
        Return the ``groups`` attribute of the matmul behind ``softmax_block``.

        ``softmax_block.matmul`` is either the HailoMatmul itself, or a HailoElementwiseAdd /
        HailoNormalization whose first predecessor is the HailoMatmul.

        Raises:
            ValueError: when neither of the two layouts matches.
        """
        layers = self._model.layers
        head = layers[softmax_block.matmul]
        if isinstance(head, HailoMatmul):
            return head.groups

        if isinstance(head, (HailoElementwiseAdd, HailoNormalization)):
            pred = self._model.flow.predecessors_sorted(softmax_block.matmul)[0]
            if isinstance(layers[pred], HailoMatmul):
                return layers[pred].groups

        raise ValueError(
            f"Attmpted to accses the matmul layer of a softmax block at the layer {softmax_block.matmul} and failed"
        )

    def _get_reduce_max_min(self, softmax_block: SoftmaxBlock):
        ew_sub_succ = self._model.flow.successors_sorted(softmax_block.ew_sub)
        ew_sub_succ = self.skip_precision_change(ew_sub_succ)
        input_bits_ = np.min(
            [el.bits for succ in ew_sub_succ for el in self._model.layers[succ].get_input_lossy_elements()]
        )

        self._logger.debug(f"on layer {softmax_block.ew_sub}  bittttttts {input_bits_} ")
        output_stats_reduce_max = self._model.layers[softmax_block.reduce_max].get_output_stats()[0]
        groups = self._get_groups(softmax_block)
        return self._calc_new_min(output_stats_reduce_max.min, input_bits_, groups)

    def apply_smart_softmax_range(self, softmax_block: SoftmaxBlock):
        """
        Apply min on all maximums of softmax base on reduce_max layer
        for every vector x we reduce from it its max value and only then pass it to softmax.
        So actually the main value that is passed is the x- max(x).

        Instead of using the regular min max limvals which is for every x:
            limvals = min(min(x)),  max(max(x))

        We will use:
            limvals = min(max(x)) + np.log(1/2**(bits+1)),  max(max(x))

        """

        new_min = self._get_reduce_max_min(softmax_block)
        # the is a vevtor of shap (1, channels = (channnel_per_group*groups))
        matmul_layer = self._model.layers[softmax_block.matmul]
        out_stats_of_input = matmul_layer.get_output_stats()[0]

        if softmax_block.mull is not None:
            mull_layer = self._model.layers[softmax_block.mull]
            new_min_matmul = new_min / mull_layer.kernel.numpy()[0][0][0][0]
        else:
            new_min_matmul = new_min

        # Get mins based on groups
        groups = self._get_groups(softmax_block)
        stats_group_min = out_stats_of_input.min.reshape(groups, -1).min(axis=1)
        best_min = np.maximum(new_min_matmul, stats_group_min)

        self.apply_range(softmax_block, best_min)

    @staticmethod
    def _calc_new_min(min_of_reduce_max, bits, groups: int = 1):
        """
        min_of_reduce_max - a vector of the min  output stats by channel(groups) of reduce_max layer
        bits- number of bits in the output of matmul.
        """
        to_reduce = np.log(1 / 2 ** (bits + 1))
        return np.min(min_of_reduce_max.reshape(groups, -1), axis=1) + to_reduce

    def apply_range(self, softmax_block: SoftmaxBlock, new_min: np.ndarray, new_max: np.ndarray = None):
        matmul_layer = self._model.layers[softmax_block.matmul]
        matmul_layer.keep_original_output_stats()
        limvals_before = matmul_layer.get_output_limvals()[0]

        # the is a vevtor of shap (1, channels = (channnel_per_group*groups))
        out_stats_of_input = matmul_layer.get_output_stats()[0]

        # Get mins based on groups
        groups = self._get_groups(softmax_block)
        out_stats_of_input.min[...] = new_min.repeat(out_stats_of_input.min.size / groups)
        if new_max is not None:
            out_stats_of_input.max[...] = new_max.repeat(out_stats_of_input.min.size / groups)

        limvals_after = matmul_layer.get_output_limvals()[0]
        self._logger.debug(
            f"on layer {matmul_layer.full_name} -set limvals to {limvals_after}  instead of {limvals_before} ",
        )

        # check is successor is precision change
        successors = self._model.flow.successors_sorted(matmul_layer.full_name)

        for successor in successors:
            if "precision_change" in successor:
                precision_change = self._model.layers[successor]
                precision_change.keep_original_output_stats()
                input_stats_pc = precision_change.get_input_stats()[0]
                input_stats_pc.min[...] = out_stats_of_input.min.copy()
                input_stats_pc.max[...] = out_stats_of_input.max.copy()
                output_stats_pc = precision_change.get_output_stats()[0]
                output_stats_pc.min[...] = out_stats_of_input.min.copy()
                output_stats_pc.max[...] = out_stats_of_input.max.copy()
        successors = self.skip_precision_change(successors)
        for successor in successors:
            if "softmax_mask" in successor:
                mask_layer = self._model.layers[successor]
                mask_layer.keep_original_output_stats()
                input_stats_pc = mask_layer.get_input_stats()[0]
                input_stats_pc.min[...] = out_stats_of_input.min.copy()
                input_stats_pc.max[...] = out_stats_of_input.max.copy()
                output_stats_pc = mask_layer.get_output_stats()[0]
                output_stats_pc.min[...] = out_stats_of_input.min.copy()
                output_stats_pc.max[...] = out_stats_of_input.max.copy()
        successors = self.skip_ew_mult_mask(successors)
        for successor in successors:
            if "precision_change" in successor:
                precision_change = self._model.layers[successor]
                precision_change.keep_original_output_stats()
                input_stats_pc = precision_change.get_input_stats()[0]
                input_stats_pc.min[...] = out_stats_of_input.min.copy()
                input_stats_pc.max[...] = out_stats_of_input.max.copy()
                output_stats_pc = precision_change.get_output_stats()[0]
                output_stats_pc.min[...] = out_stats_of_input.min.copy()
                output_stats_pc.max[...] = out_stats_of_input.max.copy()

        reduce_max_layer = self._model.layers[softmax_block.reduce_max]
        input_stats_rm = reduce_max_layer.get_input_stats()[0]
        input_stats_rm.min[...] = out_stats_of_input.min.copy()
        input_stats_rm.max[...] = out_stats_of_input.max.copy()

        ew_sub_layer = self._model.layers[softmax_block.ew_sub]
        input_stats_ews = ew_sub_layer.get_input_stats()[0]
        input_stats_ews.min[...] = out_stats_of_input.min.copy()
        input_stats_ews.max[...] = out_stats_of_input.max.copy()

        if softmax_block.resize1 is not None:
            resize1_layer = self._model.layers[softmax_block.resize1]
            ew_sub_input1_max_stats = resize1_layer.get_output_stats()[0].max.copy()
        else:  # in case the reduce_max goes straight to ew_sub (with input repeats inside ew_sub)
            repeats = ew_sub_layer.input_repeats[1][-1]
            ew_sub_input1_max_stats = np.repeat(reduce_max_layer.get_output_stats()[0].max.copy(), repeats)
        preact_stats_ews = ew_sub_layer.get_preact_stats()[0]
        preact_stats_ews.min[...] = np.maximum(
            out_stats_of_input.min.copy() - ew_sub_input1_max_stats,
            preact_stats_ews.min,
        )

    def disable_exponent_activation_fitting(self, softmax_block: SoftmaxBlock):
        """
        Disable activation fitting on the exponent (ew_sub) layer unless it was already configured.

        The layer's translation config is taken from the finalized layer configs, falling back
        to the default config. If its ``meta`` has no "activation_fit" entry, the policy is set
        to disabled and an entry is recorded; an existing entry is left alone. Skipping the
        fitting saves time, since the predefined piecewise values of the softmax block are
        almost identical to the fitted ones.
        """
        ew_sub_name = softmax_block.ew_sub
        finalized = self.finalize_layer_cfg(self._model_config.translation_config.layers)
        self._model_config.translation_config.layers[ew_sub_name] = finalized.get(
            ew_sub_name, LayerTranslationConfig.get_default()
        )

        layer_cfg = self._model_config.translation_config.layers[ew_sub_name]
        meta = layer_cfg.meta
        if meta is None:
            meta = dict()
        if "activation_fit" in meta.keys():
            return

        meta["activation_fit"] = CommandMeta(line=-1, command="", is_glob=False)
        layer_cfg.activation_fit = ActivationFitPolicy.disabled
        layer_cfg.meta = meta

    def finalize_global_cfg(self, algo_config):
        pass

    def _get_valid_layer_cfg(self, lname, cfg):
        return cfg
