import logging
import re
from abc import ABC

import numpy as np

from hailo_model_optimization.acceleras.model_optimization_config.mo_config import ModelOptimizationConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import TrackerStage
from hailo_model_optimization.algorithms.algorithm_base import AlgorithmBase
from hailo_sdk_common.hailo_nn.hailo_nn import HailoNN
from hailo_sdk_common.logger.logger import default_logger
from hailo_sdk_common.model_params.model_params import ModelParams


class FuserAlgorithm(AlgorithmBase, ABC):
    """
    Base class for all the algorithms in post fuser.

    Args:
        model: Mutable, the algorithm may change the model
        params: Mutable, the algorithm may change the params
        model_config: ModelOptimizationConfig - Params needed for the block
        hw_arch: stored on the instance as ``_hw_arch``; not otherwise used by this base class
        **kwargs: accepted and ignored by this base class
    Example:
        >>> model = HailoNN()
        >>> params = ModelParams()
        >>> fuser_algo = FuserAlgorithm(model, params, model_config, hw_arch)
        >>> fuser_algo.run()

    """

    # Passed to AlgorithmBase.__init__ as the algorithm name.
    # NOTE(review): presumably concrete subclasses override this - confirm.
    NAME = None

    def __init__(self, model: HailoNN, params: ModelParams, model_config: ModelOptimizationConfig, hw_arch, **kwargs):
        logger = default_logger()
        super().__init__(model, model_config, self.NAME, logging.DEBUG, logger)
        self._params = params
        self._hw_arch = hw_arch

    @property
    def model(self):
        """Return ``self._model`` (presumably assigned by ``AlgorithmBase.__init__`` - confirm)."""
        return self._model

    @property
    def params(self):
        """Return the params object handed to the constructor."""
        return self._params

    def _move_fused_slice_params(
        self,
        dst_layer,
        src_layer,
        src_params,
        features_slice,
        input_defuse=False,
        pred_output_shape=None,
        groups=1,
    ):
        """
        Slice the kernel/bias of ``src_layer`` and store the result under ``dst_layer``.

        The params entry of ``src_layer`` is removed from ``self.params`` and the
        new ``"<dst_layer.name>/..."`` entries are written instead.

        Args:
            dst_layer: layer that receives the sliced params (only ``name`` is read).
            src_layer: layer whose params are sliced (``name`` and, in the default
                branch, ``output_features`` are read).
            src_params: mapping of param keys *without* the layer-name prefix,
                i.e. created using ``params[layer.name]``. Must contain
                ``"kernel:0"`` and ``"bias:0"``.
            features_slice: indexable ``(start, end)`` pair describing the slice.
            input_defuse: when True (and the kernel is not 2-D), axis 2 of the
                kernel is sliced and the bias is kept only on the source layer.
            pred_output_shape: used only for 2-D (dense) kernels. When it has 4
                entries the kernel rows are interpreted through
                ``pred_output_shape[1:]`` before slicing; ``None`` or any other
                length slices the kernel rows directly.
            groups: number of groups for the default (output slicing) branch.

        Raises:
            StopIteration: if ``"kernel:0"`` or ``"bias:0"`` is missing from
                ``src_params``.

        """
        # src_params keys should be without the layer name, created using params[layer.name]
        src_kernel = np.array(next(y for x, y in src_params.items() if x == "kernel:0"))
        src_bias = np.array(next(y for x, y in src_params.items() if x == "bias:0"))

        # Carry over whichever extra params _should_keep_param approves, re-keyed for dst_layer.
        new_params = {f"{dst_layer.name}/{x}": y for x, y in src_params.items() if self._should_keep_param(x)}

        slice_start, slice_end = features_slice[0], features_slice[1]

        if src_kernel.ndim == 2:  # dense
            # pred_output_shape defaults to None, so guard before taking its length.
            if pred_output_shape is not None and len(pred_output_shape) == 4:
                f_out = src_kernel.shape[-1]
                # list(...) so tuples are accepted and numpy arrays are not added element-wise.
                unflattened_shape = list(pred_output_shape[1:]) + [f_out]
                reshaped_src_kernel = src_kernel.reshape(unflattened_shape)
                # Slice axis 2 of the un-flattened kernel, then flatten back to 2-D.
                new_kernel = reshaped_src_kernel[:, :, slice_start:slice_end, :].reshape([-1, f_out])
            else:
                new_kernel = src_kernel[slice_start:slice_end, :]
            # when splitting dense, we take the bias only once over the original dense layer.
            new_bias = np.zeros(src_bias.shape) if src_layer != dst_layer else src_bias.copy()
        elif input_defuse:
            new_kernel = src_kernel[:, :, slice_start:slice_end, :]
            # Same as dense: the bias stays only on the original layer.
            new_bias = np.zeros(src_bias.shape) if src_layer != dst_layer else src_bias.copy()
        else:
            # Slice the last kernel axis (and the bias) per group, then stitch the groups together.
            kernel_slices, bias_slices = [], []
            group_size = src_layer.output_features // groups
            start_in_group = slice_start // groups
            end_in_group = slice_end // groups
            for i in range(groups):
                group_offset = i * group_size
                group_start = group_offset + start_in_group
                group_end = group_offset + end_in_group
                kernel_slices.append(src_kernel[:, :, :, group_start:group_end])
                bias_slices.append(src_bias[group_start:group_end])
            new_kernel = np.concatenate(kernel_slices, axis=-1)
            new_bias = np.concatenate(bias_slices, axis=-1)

        new_params.update({f"{dst_layer.name}/kernel:0": new_kernel, f"{dst_layer.name}/bias:0": new_bias})
        # Remove first: src_layer may be the same layer as dst_layer.
        self.params.remove(src_layer.name)
        self.params.update(new_params)

    @staticmethod
    def _should_keep_param(x):
        """
        Decide whether the param key ``x`` is carried over by ``_move_fused_slice_params``.

        This base implementation returns ``None`` (falsy), so no additional params
        are carried over. NOTE(review): presumably subclasses override it - confirm.
        """
        pass

    @staticmethod
    def get_block_and_layer_names(layer_name):
        """
        Split ``layer_name`` into its block prefix and the remaining layer name.

        Returns:
            ``(block_name, layer_name)``. ``block_name`` includes the trailing
            ``"__"`` (e.g. ``"block3__"``) and is ``""`` when no block prefix is
            found, in which case ``layer_name`` is returned unchanged.

        """
        # the block name is a part of the layer name, separated by "__"
        # for example: "block3__conv4"
        search_block = re.search(r"block(\d+)__", layer_name)
        if not search_block:
            return "", layer_name
        block_name = search_block.group(0)
        # block_name is a plain literal (letters, digits, underscores), so str.replace is enough.
        return block_name, layer_name.replace(block_name, "")

    def get_modifications_meta_data(self):
        """Mark the modifications meta data with the FP_OPTIMIZE stage and return it."""
        self._modifications_meta_data.set_stage(TrackerStage.FP_OPTIMIZE)
        return self._modifications_meta_data
