from typing import List, Tuple

import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_conv import BaseHailoConv
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_concat import HailoConcat
from hailo_model_optimization.acceleras.hailo_layers.hailo_const import HailoConst
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_add import HailoElementwiseAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_element_wise_sub import HailoElementwiseSub
from hailo_model_optimization.acceleras.hailo_layers.hailo_io import HailoInputLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_matmul import HailoMatmul
from hailo_model_optimization.acceleras.hailo_layers.hailo_precision_split_signed import HailoPrecisionSplitSigned
from hailo_model_optimization.acceleras.hailo_layers.hailo_reduce_max import HailoReduceMax
from hailo_model_optimization.acceleras.hailo_layers.hailo_reduce_sum import HailoReduceSum
from hailo_model_optimization.acceleras.hailo_layers.hailo_resize_nearest_neighbor import HailoResizeNearestNeighbor
from hailo_model_optimization.acceleras.hailo_layers.hailo_shortcut import HailoShortcut
from hailo_model_optimization.acceleras.hailo_layers.hailo_softmax_mask import HailoSoftmaxMask
from hailo_model_optimization.acceleras.hailo_layers.hailo_softmax_mask_on_mac import HailoSoftmaxMaskOnMac
from hailo_model_optimization.acceleras.hailo_layers.hailo_standalone_activation import HailoStandaloneActivation
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.model.hailo_model.model_flow import ModelFlow
from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerMatmulEqualizationConfig
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_model import MatmulEqualizationConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    SHIFT_CALCULATE_BUFFER,
    MatmulCorrectionType,
    ThreeWayPolicy,
)
from hailo_model_optimization.acceleras.utils.opt_utils import calculate_shifts, limvals_to_zp_scale
from hailo_model_optimization.algorithms.matmul_correction.correction_blocks import (
    MMCorrectionBlock3,
)
from hailo_model_optimization.algorithms.matmul_equalization.optimal_zp_finder import (
    find_the_best_zp,
    minimize_scale_dof,
    minimize_scale_dof_and_zp,
)
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm
from hailo_model_optimization.algorithms.smart_softmax_stats.smart_softmax_stats import SmartSoftmaxStats, SoftmaxBlock


class UnsuportedLayerError(ValueError):
    """Raised when a layer type that the backwards walk cannot handle is met.

    Attributes:
        layer_name: name of the offending layer, kept so callers can report it.
    """

    def __init__(self, message, layer_name: str):
        # Keep the offending layer's name next to the regular ValueError message.
        self.layer_name = layer_name
        super().__init__(message)


class MatmulEqualization(OptimizationAlgorithm):
    """
    Equalizes the Matmul block (for a transpose Matmul).
    This algorithm will work with the producers of the
    Matmul and set the scales group-wise. It will propagate
    those scales across the paths and set the Matmul to work
    with these scales. Then, it will propagate the resulting
    scales also group-wise until the ew_sub.
    """

    # Zero-point correction types of a matmul on which this algorithm may run.
    SUPPORTED_CORRECTIONS = (
        MatmulCorrectionType.ZP_COMP_BLOCK,
        MatmulCorrectionType.ZP_COMP_BLOCK_2,
        MatmulCorrectionType.ZP_COMP_BLOCK_3,
        MatmulCorrectionType.ZP_COMP_WEIGHTS,
    )

    # Layer types accepted inside the sub-flow between the matmul and the
    # ew_sub of a softmax block (see check_validity).
    SOFTMAX_SUPPORTED = (
        HailoMatmul,
        HailoElementwiseSub,
        HailoReduceMax,
        HailoShortcut,
        HailoResizeNearestNeighbor,
        HailoStandaloneActivation,
        HailoSoftmaxMask,
        HailoSoftmaxMaskOnMac,
    )

    def __init__(
        self,
        model: HailoModel,
        model_config: ModelOptimizationConfig,
        logger_level,
        logger,
        **kwargs,
    ):
        """Build the algorithm and the SmartSoftmaxStats helper it relies on."""
        super().__init__(
            model, model_config, logger_level=logger_level, logger=logger, name="Matmul Equalization", **kwargs
        )
        self._cfg: MatmulEqualizationConfig = self.get_algo_config()
        self._softmax = SmartSoftmaxStats(model, model_config, logger_level, logger)

    def _setup(self):
        """Delegate the setup step to the softmax statistics helper."""
        self._softmax._setup()

    def _run_int(self):
        """Equalize every configured matmul layer that passes check_validity.

        For each valid layer the two input sub-flows (normal and transposed)
        are collected, their scales are equalized group-wise, and the resulting
        matmul output encoding is propagated through the softmax block.
        """
        flow = self._model.flow
        layers = self._model.layers
        self._softmax.find_and_build_softmax_blocks()
        for lname in self._cfg.layers:
            if not self.check_validity(lname, self._softmax.softmax_blocks):  # TODO add better check
                continue

            layer: HailoMatmul = layers[lname]
            self.verify_layers(layer)
            # Same predicate as check_validity, so the block used here is the
            # one that was validated (it is guaranteed to exist at this point).
            block: SoftmaxBlock = next(
                block for block in self._softmax.softmax_blocks if block.matmul == lname and block.ew_sub
            )
            normal_layer, transpose_layer = list(flow.predecessors_sorted(layer.full_name))
            try:
                if layer.zp_correction_type in [
                    MatmulCorrectionType.ZP_COMP_BLOCK_3,
                    MatmulCorrectionType.ZP_COMP_BLOCK_2,
                    MatmulCorrectionType.ZP_COMP_BLOCK,
                ]:
                    # Block-based corrections: the correction block supplies the transpose sub-flow.
                    mmc = MMCorrectionBlock3(self._model, self._model_config, self._logger, layer.zp_comp_added)
                    transpose_subflow = mmc.get_sub_flow(lname)
                else:
                    transpose_subflow = flow.subgraph(walk_backwards(transpose_layer, self._model, [layer.full_name]))

                normal_subflow = flow.subgraph(walk_backwards(normal_layer, self._model, [layer.full_name]))

            except UnsuportedLayerError as e:
                self._logger.info(
                    f"{self._name}: Could not be apply because it "
                    f"has an unsuported layer {e.layer_name} on its inputs flow"
                )
                # NOTE(review): this stops the whole algorithm, not only the current
                # layer - confirm that skipping the remaining layers is intended.
                return

            self.equalize_input_paths(normal_subflow, transpose_subflow, layer)
            self.matmul_propagation(block)

    # region Abstract Methods

    def _check_layer_supported(self, layer):
        """Return a truthy value when `layer` is a grouped, transposed matmul with a supported correction type."""
        return (
            isinstance(layer, HailoMatmul)
            and layer.transpose_matmul_input
            and layer.zp_correction_type in self.SUPPORTED_CORRECTIONS
            and layer.groups > 1
        )

    def should_skip_algo(self):
        """Skip the algorithm when the config holds no layers to work on."""
        return not self._cfg.dict().get("layers", {})

    def get_algo_config(self) -> MatmulEqualizationConfig:
        """Return the matmul-equalization section of the model config."""
        return self._model_config.matmul_equalization

    def finalize_global_cfg(self, algo_config):
        """
        Finalize the algorithm's config. (values that are not layer specific)
        Can include values verification, fetching data from the other algo's config, etc...

        Every supported layer of the model that has no entry yet receives a
        default LayerMatmulEqualizationConfig.
        """
        for layer in self._model.layers.values():
            if self._check_layer_supported(layer):
                algo_config.layers.setdefault(layer.full_name, LayerMatmulEqualizationConfig())
        super().finalize_global_cfg(algo_config)

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return `cfg` unchanged for supported layers and an empty dict otherwise."""
        if not self._check_layer_supported(self._model.layers[lname]):
            cfg = {}
        return cfg

    # region Helper Methods
    def equalize_input_paths(self, normal_flow: ModelFlow, transpose_flow: ModelFlow, layer: HailoMatmul):
        """Set group-wise scales on both input sub-flows of `layer`.

        At most two passes are made. Each pass applies the current per-group
        factors to both flows, fixes the scales of the correction path, and
        recomputes the factors with get_real_limvals. The loop ends early once
        all factors are (numerically) one.
        """
        factors = np.ones(layer.groups, dtype=np.float32)

        for _ in range(2):
            set_scales_forward(self._model, normal_flow, layer.groups, factors)
            set_scales_forward(self._model, transpose_flow, layer.groups, factors, transpose=True)

            fix_scales_reduce_sum(self._model, transpose_flow, layer.groups, layer, self._logger)

            factors = get_real_limvals(layer)
            if np.allclose(factors, np.ones(layer.groups)):
                break

    def check_validity(self, lname: str, softmax_blocks: List[SoftmaxBlock]):
        """Tell whether the algorithm can run on the matmul layer `lname`.

        Returns:
            True when the layer's policy is not `disabled`, a softmax block with
            an ew_sub exists for it, and every layer between the matmul and the
            ew_sub is one of SOFTMAX_SUPPORTED. False otherwise.

        Raises:
            ValueError: when a requirement is not met while the layer's policy
                is explicitly `enabled`.
        """
        policy = self._cfg.layers[lname].policy

        def raise_or_false():
            # An explicitly enabled layer that cannot run is a user error.
            if policy == ThreeWayPolicy.enabled:
                raise ValueError(
                    f"{self._name}: {lname}: is set to policy enable but does not meet the requiremts to run"
                )
            return False

        if policy == ThreeWayPolicy.disabled:
            return raise_or_false()

        model = self._model
        softmax_block = next((block for block in softmax_blocks if block.matmul == lname and block.ew_sub), None)
        if softmax_block is None:
            return raise_or_false()

        sub_flow = model.flow.get_sub_flow(softmax_block.matmul, softmax_block.ew_sub)
        all_legal = all(isinstance(model.layers[ln], self.SOFTMAX_SUPPORTED) for ln in sub_flow.nodes)
        if not all_legal:
            self._logger.warning(
                f"{self._name}: {lname} softmax block have unsupported "
                f"layers {self._name} will not be apply on this layer"
            )
            return raise_or_false()

        # TODO check backwards

        return True

    def verify_layers(self, matmul):
        """Validate that `matmul` is a layer this algorithm can process.

        Raises:
            ValueError: if `matmul` is not a HailoMatmul, has no zero-point
                compensation added, or does not transpose its matmul input.
        """
        if not isinstance(matmul, HailoMatmul):
            raise ValueError(f"{self._name}: Only works on matmul layers")
        # The two checks below used to build the exception without raising it,
        # which turned them into no-ops.
        if not matmul.zp_comp_added:
            raise ValueError(f"{self._name}: Only works on matmul layers that have Zp Comp Added")
        if not matmul.transpose_matmul_input:
            raise ValueError(f"{self._name}: Only works on matmul layers that Have Transpose input")

    def matmul_propagation(self, softmax_block: SoftmaxBlock) -> None:
        """Choose the matmul output encoding and propagate it through the softmax block.

        The per-group pre-activation scale is the product of the per-group mean
        input scales. When the layer's `matmul_bias` policy is enabled and the
        correction type is ZP_COMP_BLOCK_2 or ZP_COMP_BLOCK_3, a bias is first
        added to the matmul and the softmax block range is refreshed. Then the
        output scale and zero point are set on the matmul and the constraints
        are enforced over the edges of the matmul -> ew_sub sub-flow.
        """
        model = self._model
        config = self.get_algo_config()
        softmax_flow = model.flow.get_sub_flow(softmax_block.matmul, softmax_block.ew_sub)

        matmul: HailoMatmul = model.layers[softmax_block.matmul]
        groups = matmul.groups
        lossy_element = matmul.get_output_lossy_elements()[0]

        # quant_groups = matmul.activation_atomic_op.quantization_groups_num
        scale_a = np.array(matmul.input_scales[0]).reshape((groups, -1)).mean(axis=1)
        scale_b = np.array(matmul.input_scales[1]).reshape((groups, -1)).mean(axis=1)
        pre_act_scale = scale_a * scale_b

        bins = lossy_element.bins_count

        lim_vals = matmul.get_output_limvals()[0]
        candidate_zp, candidate_scales, _ = limvals_to_zp_scale(lim_vals, lossy_element)

        start_dof = np.max(candidate_scales) / np.min(pre_act_scale)

        stats_min, stats_max = matmul.get_group_output_limvals(groups)[0]
        if config.layers[softmax_block.matmul].matmul_bias == ThreeWayPolicy.enabled and matmul.zp_correction_type in [
            MatmulCorrectionType.ZP_COMP_BLOCK_2,
            MatmulCorrectionType.ZP_COMP_BLOCK_3,
        ]:
            scale_dof, new_zp = minimize_scale_dof_and_zp(
                max_vector=stats_max,
                min_vector=stats_min,
                acc_vector=pre_act_scale,
                bins=bins,
                start_zp=[candidate_zp] * groups,
                start_dof=start_dof,
            )
            bias = scale_dof * pre_act_scale * new_zp
            mmc = MMCorrectionBlock3(model, None, None, True)
            mmc.add_bias_to_matmul(softmax_block.matmul, bias)
            # The bias changes the matmul output, so re-read the group limits.
            stats_min, stats_max = matmul.get_group_output_limvals(groups)[0]
            self._softmax.apply_range(softmax_block, stats_min, stats_max)

        scale_dof, new_zp = minimize_scale_dof(
            max_vector=stats_max,
            min_vector=stats_min,
            acc_vector=pre_act_scale,
            bins=bins,
            start_zp=candidate_zp,
            start_dof=start_dof,
        )

        # Expand the per-group scale to one value per output channel; the repeat
        # count has to be an integer, hence the floor division.
        final_out_s = (scale_dof * pre_act_scale).repeat(matmul.output_shape[-1] // groups)

        matmul.set_output_scale(final_out_s, 0)
        matmul.set_output_zero_point(np.float32(round(new_zp)), 0)
        edges = softmax_flow.toposort_edges()
        model.enforce_constraints(edges)


# region Walk Functions


def walk_backwards(node: str, model: HailoModel, nodes: list) -> List[str]:
    """Collect, depth first, the layers feeding `node` into `nodes`.

    `node` itself is always appended. The walk stops at conv, input and
    elementwise-add layers (a precision-split-signed layer is never treated as
    a stop point); for any other supported layer it continues through the
    sorted predecessors.

    Args:
        node: name of the layer to start from.
        model: model that owns the layers and the flow.
        nodes: list that is extended in place with the visited layer names.

    Returns:
        The same `nodes` list that was passed in.

    Raises:
        UnsuportedLayerError: if a visited layer is of a type the walk cannot handle.
    """
    walkable_types = (
        BaseHailoConv,
        HailoInputLayer,
        HailoConcat,
        HailoReduceSum,
        HailoStandaloneActivation,
        HailoMatmul,
        HailoConst,
        HailoPrecisionSplitSigned,
        HailoElementwiseAdd,
    )
    stop_types = (BaseHailoConv, HailoInputLayer, HailoElementwiseAdd)

    current = model.layers[node]
    if not isinstance(current, walkable_types):
        raise UnsuportedLayerError(f"Layer to walk backwards is not supported: {current.full_name}", current.full_name)

    nodes.append(node)

    reached_boundary = isinstance(current, stop_types) and not isinstance(current, HailoPrecisionSplitSigned)
    if reached_boundary:
        return nodes

    for parent in model.flow.predecessors_sorted(node):
        walk_backwards(parent, model, nodes)
    return nodes


# region Propagation Functions


def get_real_limvals(matmul: HailoMatmul):
    """Return the per-group factor needed to bring the matmul input scales to a feasible value.

    The candidate pre-activation scale of each group is the product of the
    per-group means of the two input scales. It is raised, where needed, to a
    minimum derived from the activation input limits, the accumulator maximum
    value and the smallest rounded shift. The result is the square root of
    corrected / candidate scale, as float32.
    """
    # quant_groups = matmul.activation_atomic_op.quantization_groups_num
    n_groups = matmul.groups
    mean_scale_in0 = np.array(matmul.input_scales[0]).reshape((n_groups, -1)).mean(axis=1)
    mean_scale_in1 = np.array(matmul.input_scales[1]).reshape((n_groups, -1)).mean(axis=1)
    candidate_scale = mean_scale_in0 * mean_scale_in1

    group_min, group_max = matmul.act_op.get_group_input_limvals(0, matmul.groups)

    acc_element = matmul.act_op.input_lossy_element
    accumulator_max = acc_element.max_value
    abs_limits = np.maximum(np.abs(group_max), np.abs(group_min))

    expected_peak = abs_limits / candidate_scale
    needed_shift, _, _ = calculate_shifts(
        expected_peak, acc_element.bits, SHIFT_CALCULATE_BUFFER * 0.999, return_needed_shift=True
    )

    smallest_shift = np.min(np.round(needed_shift))
    min_possible_scale = abs_limits * 2 ** (SHIFT_CALCULATE_BUFFER) / (accumulator_max * 2 ** (smallest_shift))
    corrected_scale = np.maximum(min_possible_scale, candidate_scale)
    corrected_peak = abs_limits / corrected_scale

    # The result of this second call is not used; the call itself is kept as in
    # the original flow (NOTE(review): confirm whether it is still needed).
    calculate_shifts(corrected_peak, acc_element.bits, SHIFT_CALCULATE_BUFFER, return_needed_shift=True)

    return np.sqrt(corrected_scale / candidate_scale, dtype=np.float32)


def set_scales_forward(
    model: HailoModel,
    flow: ModelFlow,
    groups: int,
    factors: np.ndarray,
    transpose=False,
) -> None:
    """Set group-wise output encoding on the first source of `flow` and push it through the flow.

    The scales and zero point come from get_scales_output. When the source is
    an input layer its input encoding is set as well. Every standalone
    activation in the flow gets its `output_scale_scalar_dof` set to
    2 ** (input bits - output bits) before the constraints are enforced over
    the flow edges.
    """
    source_name = flow.get_sources()[0]
    source_layer = model.layers[source_name]

    out_scales, out_zp = get_scales_output(source_layer, groups, factors, transpose)
    source_layer.set_output_scale(out_scales, 0)
    source_layer.set_output_zero_point(np.float32(out_zp), 0)
    if isinstance(source_layer, HailoInputLayer):
        source_layer.set_input_scale(out_scales, 0)
        source_layer.set_input_zero_point(round(out_zp), 0)

    for name in flow.nodes():
        candidate = model.layers[name]
        if not isinstance(candidate, HailoStandaloneActivation):
            continue
        bits_in = candidate.get_input_lossy_elements()[0].bits
        bits_out = candidate.get_output_lossy_elements()[0].bits
        candidate.output_scale_scalar_dof = 2 ** (bits_in - bits_out)

    sorted_edges = flow.toposort_edges()
    model.enforce_constraints(sorted_edges, training=False, create_ratio=False)


def get_scales_output(
    layer: BaseHailoLayer, groups: int, factors: np.ndarray, transpose=False
) -> Tuple[np.ndarray, int]:
    """Compute group-wise output scales and a zero point for `layer`.

    Args:
        layer: layer whose first output is encoded.
        groups: number of groups the output channels are split into.
        factors: per-group multipliers applied to the computed scales.
        transpose: for an unsigned output, use a range symmetric around the
            middle bin instead of searching for the best zero point.

    Returns:
        A tuple `(scales, zp)`. `scales` holds the per-group scale repeated
        `output_shape[-1] // groups` times; `zp` is the zero point
        (0 for a signed output).
    """
    lossy_element = layer.get_output_lossy_elements()[0]
    bins = lossy_element.bins_count
    min_vals, max_vals = layer.get_group_output_limvals(groups)[0]
    # We can do this for every quantization group
    if lossy_element.signed:
        scales = np.array([limvals_to_zp_scale(lim_vals, lossy_element)[1] for lim_vals in zip(min_vals, max_vals)])
        zp = 0
    elif transpose:
        # Symmetric range: the larger absolute limit spans half of the bins.
        scales = np.maximum(np.abs(min_vals), np.abs(max_vals)) * 2 / bins
        zp = (bins + 1) // 2
    else:
        candidate_zp, candidate_scale, _ = limvals_to_zp_scale(layer.get_output_limvals()[0], lossy_element)
        scales, zp = find_the_best_zp(
            xmin=min_vals,
            xmax=max_vals,
            bins=bins,
            start_scales=np.repeat(candidate_scale, max_vals.size),
            start_zp=candidate_zp,
        )

    scales = scales * factors
    # The repeat count must be an integer: use floor division, not true division.
    scales = scales.repeat(layer.output_shape[-1] // groups)
    return scales, zp


def fix_scales_reduce_sum(model: HailoModel, flow: ModelFlow, groups: int, matmul: HailoMatmul, logger):
    """Align the scales of the zero-point correction path with the linear path.

    Nothing is done for ZP_COMP_WEIGHTS. For ZP_COMP_BLOCK_3 the correction
    block enforces its own encodings from the input scales of its linear layer.
    For ZP_COMP_BLOCK and ZP_COMP_BLOCK_2 the per-group mean of the linear
    output scale is written to the second concat predecessor (or to the layer
    feeding it), after which the concat output scale is rebuilt from both
    predecessors. Constraints are then enforced over the flow edges.

    `logger` is accepted but not used.
    """
    correction = matmul.zp_correction_type
    if correction is MatmulCorrectionType.ZP_COMP_WEIGHTS:
        return

    concat_names = [name for name in flow.toposort() if isinstance(model.layers[name], HailoConcat)]
    concat_node = concat_names[0]
    edges = flow.toposort_edges()

    linear, lv1 = [model.layers[name] for name in flow.predecessors_sorted(concat_node)]

    def group_mean(scale):
        # Mean of the scale over the channels of each group.
        return np.array(scale).reshape((groups, -1)).mean(axis=1)

    if correction is MatmulCorrectionType.ZP_COMP_BLOCK_3:
        block = MMCorrectionBlock3(model, None, None, True)
        block_names = block.get_block_names(matmul.full_name)
        linear_in_scales = model.layers[block_names.linear].input_scales[0]
        block.enforce_block_encodings(matmul.full_name, linear_in_scales)
        model.enforce_constraints(edges, training=False, create_ratio=False)
        return

    if correction is MatmulCorrectionType.ZP_COMP_BLOCK:
        lv1.set_output_scale(group_mean(linear.output_scale), 0)
    elif correction is MatmulCorrectionType.ZP_COMP_BLOCK_2:
        producer_name = list(flow.predecessors(lv1.full_name))[0]
        reduce_sum = model.layers[producer_name]
        reduce_sum.set_output_scale(group_mean(linear.output_scale), 0)
        lv1.set_input_scale(reduce_sum.output_scales[0], 0)
        lv1.enforce_io_encoding()

    concat = model.layers[concat_node]
    merged_scale = np.concatenate([linear.output_scale, lv1.output_scale], axis=0)
    concat.set_output_scale(merged_scale, 0)
    model.enforce_constraints(edges, training=False, create_ratio=False)
