import numpy as np

from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    SEOptimizationMethod,
    TiledSqueezeAndExciteMode,
)
from hailo_sdk_client.post_fuser.algorithms.exceptions import TileSEOptimizerException
from hailo_sdk_client.post_fuser.algorithms.fuser_algorithm import FuserAlgorithm
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType, PaddingType, ResizeMethod
from hailo_sdk_common.hailo_nn.hn_layers import FusedConv2DLayer


class TiledSEOptimizer(FuserAlgorithm):
    """
    Squeeze and Excite optimization class (for Tiled S&E optimization)
    """

    NAME = "tse_optimization"

    def get_algo_config(self):
        return self._model_config.se_optimization

    def _setup(self):
        """
        Check that the configured S&E optimization method is supported and warn about its impact.

        Raises:
            NotImplementedError: if the configured method is anything other than
                ``SEOptimizationMethod.tse``.
        """
        method = self.get_algo_config().method
        if method != SEOptimizationMethod.tse:
            raise NotImplementedError(f"SE optimization method '{method.value}' is not implemented")

        self._logger.warning("Running TSE will affect your full-precision results")

    def _run_int(self):
        """
        Dispatch to the optimization routine that matches the configured TSE mode.

        ``sequential`` mode forwards the configured tile height and count to
        ``sequential_optimize``; ``custom`` mode forwards the configured layers and tile height
        to ``custom_optimize``.

        Raises:
            NotImplementedError: for any other mode.
        """
        config = self.get_algo_config()
        mode = config.mode

        if mode == TiledSqueezeAndExciteMode.sequential:
            self.sequential_optimize(config.tile_height, config.count)
            return

        if mode == TiledSqueezeAndExciteMode.custom:
            self.custom_optimize(config.layers, config.tile_height)
            return

        raise NotImplementedError(f"Unexpected tse mode: {mode.value}")

    def should_skip_algo(self):
        """
        Tell whether the algorithm is switched off, i.e. the configured mode is ``disabled``.
        """
        configured_mode = self.get_algo_config().mode
        return configured_mode == TiledSqueezeAndExciteMode.disabled

    def _filter_global_avg_pool(self):
        def _is_global_avg_pool(layer):
            return ((layer.op == LayerType.avgpool) and layer.is_global_avg_pool()) or (
                layer.op == LayerType.global_avg_pool
            )

        return filter(_is_global_avg_pool, self._model)

    def _is_special_activation(self, successors):
        # special case when the fuser break an activation to multiple layers
        if len(successors) == 2:
            for i in range(2):
                if successors[i].op is LayerType.ew_mult and successors[1 - i].op is LayerType.activation:
                    if next(iter(self._model.successors(successors[1 - i]))).name == successors[i].name:
                        return True
        return False

    def _is_pred_special_activation(self, predecessors):
        return any(self._is_special_activation(list(self._model.successors(pred))) for pred in predecessors)

    def _get_se_block(self, layer):
        """
        Collect the layers of the Squeeze and Excite block that starts at ``layer``.

        The model is walked level by level from ``layer`` through its successors until an
        ``ew_mult`` is found that is not part of a broken (special) activation; that layer closes
        the block. Each layer is collected once, even when it is reachable through more than one
        path (as happens around a special activation), so later processing of the block does not
        handle the same layer twice.

        Args:
            layer: the global average pool layer that opens the block.

        Returns:
            list: the block's layers in discovery order, starting with ``layer`` and ending with
            the closing ``ew_mult``.

        Raises:
            TileSEOptimizerException: if a layer fans out in an unsupported way, if a layer that is
                not dense / ew_mult / activation has an input height or width other than 1, if the
                walk ends without reaching a closing ``ew_mult``, or if ``_verify_se_block``
                rejects the collected block.
        """
        se_block = [layer]
        collected = {layer}
        current_level = [layer]
        while current_level:
            next_level = []
            for current in current_level:
                successors = list(self._model.successors(current))
                if len(successors) != 1 and not self._is_special_activation(successors):
                    raise TileSEOptimizerException(f"Unsupported TSE structure in layer - {current.name}")
                for succ in successors:
                    if succ in collected:
                        # Already part of the block via another path; do not collect it twice.
                        continue
                    # TODO (Optional): Define which layers can be part of S&E explicitly.
                    #   Current state - all layers with input width and height of 1 / dense are allowed.
                    #                   EW mult detects end of block.
                    if succ.op not in {LayerType.dense, LayerType.ew_mult, LayerType.activation} and (
                        (succ.input_height != 1) or (succ.input_width != 1)
                    ):
                        raise TileSEOptimizerException(f"Found invalid layer in S&E block - {succ.name}")
                    if succ.op == LayerType.ew_mult and not self._is_pred_special_activation(
                        list(self._model.predecessors(succ)),
                    ):
                        # End of block
                        se_block.append(succ)
                        self._verify_se_block(se_block)
                        return se_block
                    collected.add(succ)
                    se_block.append(succ)
                    next_level.append(succ)
            current_level = next_level
        raise TileSEOptimizerException("Finished iteration without S&E block end")

    def _verify_se_block(self, se_block):
        """
        Validate that a collected layer list forms a usable Squeeze and Excite block.

        Args:
            se_block (list): the collected layers; the first entry is the average pool, the second
                to last entry is expected to be the resize and the last entry is the closing
                ``ew_mult``.

        Raises:
            TileSEOptimizerException: if the second to last layer is not a nearest-neighbor resize
                whose output shape equals the average pool's input shape, if that layer does not
                feed the ``ew_mult`` alongside a second input, or if the other ``ew_mult`` input and
                the average pool share no common source.
        """
        avgpool = se_block[0]
        resize_layer = se_block[-2]
        ew_mult = se_block[-1]

        # Verify proper resize layer. The op is tested first so that resize-only attributes are
        # never read from a layer of another type.
        is_valid_resize = (
            resize_layer.op == LayerType.resize
            and resize_layer.output_shape == avgpool.input_shape
            and resize_layer.resize_method == ResizeMethod.nearest_neighbor
        )
        if not is_valid_resize:
            raise TileSEOptimizerException(
                f"Invalid SE block detected at {avgpool.name}, missing or invalid resize {resize_layer.name}",
            )

        # Verify se block has a common source.
        ew_mult_direct_preds = list(self._model.predecessors(ew_mult))
        if resize_layer not in ew_mult_direct_preds or len(ew_mult_direct_preds) < 2:
            # Report a malformed block with the optimizer's own exception type rather than a
            # ValueError / IndexError from the list operations below.
            raise TileSEOptimizerException(
                f"Invalid SE block detected at {avgpool.name}, "
                f"{resize_layer.name} doesn't feed {ew_mult.name} alongside a second input",
            )
        ew_mult_direct_preds.remove(resize_layer)
        ew_mult_pred = ew_mult_direct_preds[0]

        # The direct input of the ew_mult may itself be the layer feeding the avgpool, so it
        # counts as a candidate common source together with everything upstream of it.
        ew_mult_sources = self._get_all_layer_preds(ew_mult_pred) | {ew_mult_pred}
        avgpool_sources = self._get_all_layer_preds(avgpool)
        if ew_mult_sources.isdisjoint(avgpool_sources):
            raise TileSEOptimizerException(
                f"Invalid SE block detected at {avgpool.name}, "
                f"{ew_mult.name} and {avgpool.name} didn't have common source",
            )

    def _get_all_layer_preds(self, layer):
        to_handle = [layer]
        next_to_handle = []
        handled = set()
        preds = set()
        while to_handle:
            for layer in to_handle:
                if layer in handled:
                    continue
                next_to_handle.extend(list(self._model.predecessors(layer)))
                preds |= set(next_to_handle)
                handled.add(layer)
            to_handle = next_to_handle
            next_to_handle = []

        return preds

    def _optimize_tse_block(self, block, tile_width=None, tile_height=None):
        """
        Apply the tiling changes to every layer of one S&E block.

        The first layer (average pool) is tiled first; its resulting output width and height are
        then pushed to the layers between it and the last two entries, and finally to the second
        to last entry (the resize). The last entry of the block is not modified here.

        Args:
            block (list): block layers as returned by ``_get_se_block``.
            tile_width: tile width forwarded to ``_modify_global_avg_pool``.
            tile_height: tile height forwarded to ``_modify_global_avg_pool``.
        """
        gap_layer, resize = block[0], block[-2]
        pooled_width, pooled_height = self._modify_global_avg_pool(gap_layer, tile_width, tile_height)

        for inner in block[1:-2]:
            if inner.op == LayerType.dense:
                adapt = self._modify_dense_layer
            elif inner.op == LayerType.ew_mult:
                adapt = self._modify_ew_mult_shape
            else:
                adapt = self._modify_input_shape
            adapt(inner, pooled_width, pooled_height)

        self._modify_resize_layer(resize, pooled_width, pooled_height)

    def sequential_optimize(self, tile_height, se_count=None):
        gap_layers = list(self._filter_global_avg_pool())
        se_count = len(gap_layers) if se_count is None else se_count

        tile_height = self._handle_tile_height(tile_height, se_count)

        i = 0
        for gap_layer in gap_layers:
            try:
                se_block = self._get_se_block(gap_layer)
            except TileSEOptimizerException:
                continue

            if not (i < se_count):
                break
            current_tile_height = tile_height[i]
            self._logger.debug(f"Applying S&E optimization to {se_block[0].name}")
            self._optimize_tse_block(se_block, None, current_tile_height)
            i += 1
        self._logger.debug(f"Optimized {i} Squeeze and Excite blocks")

    def custom_optimize(self, avgpool_layers, tile_height):
        """
        Apply optimization to the model
        Args:
            avgpool_layers (list of str): global avg pool layers to optimize
            tile_height (list of int): wanted tile height for each avg pool
        Returns:
            modified hn (HailoNN) and modified params (ModelParams)
        """
        tile_height = self._handle_tile_height(tile_height, len(avgpool_layers))
        for i, layer in enumerate(avgpool_layers):
            self._logger.debug(f"Applying S&E optimization to {layer}")
            layer = self._model.get_layer_by_name(layer)
            block = self._get_se_block(layer)
            current_tile_height = tile_height[i]
            self._optimize_tse_block(block, None, current_tile_height)

    @staticmethod
    def _handle_tile_height(tile_height, count):
        if isinstance(tile_height, (int, float)):
            return [tile_height] * count
        else:
            if len(tile_height) != count:
                raise TileSEOptimizerException("Tile height length doesn't match layers count")
            return tile_height

    def _modify_global_avg_pool(self, layer, tile_width=None, tile_height=None):
        """
        Turn a global average pool into a tiled pool by setting its kernel and stride to the tile.

        A missing tile dimension defaults to the layer's full input size along that axis. When the
        input size is not a multiple of the tile, a warning is logged and the output size is the
        floor of the division. Padding is set to ``valid`` and the layer's output shapes are
        refreshed.

        Args:
            layer: the average pool layer to modify in place.
            tile_width (optional): tile size along the width.
            tile_height (optional): tile size along the height.

        Returns:
            tuple: ``(output_width, output_height)`` of the tiled pool.
        """
        tile_width = layer.input_width if tile_width is None else tile_width
        layer.kernel_width = layer.stride_width = tile_width
        tile_height = layer.input_height if tile_height is None else tile_height

        width_residue = layer.input_width % tile_width
        if width_residue != 0:
            self._logger.warning(
                f"{layer.name} input width {layer.input_width} can not be divided evenly using "
                f"tile width of {tile_width}, residue will be ignored",
            )

        height_residue = layer.input_height % tile_height
        if height_residue != 0:
            self._logger.warning(
                f"{layer.name} input height {layer.input_height} can not be divided evenly using "
                f"tile height of {tile_height}, residue will be ignored",
            )

        layer.kernel_height = layer.stride_height = tile_height
        tiled_shape = (layer.input_width // tile_width, layer.input_height // tile_height)
        layer.padding = PaddingType.valid
        layer.update_output_shapes()
        return tiled_shape

    @staticmethod
    def _modify_ew_mult_shape(layer, input_width, input_height):
        for i, input_shape in enumerate(layer.input_shapes):
            input_shape[1] = input_width
            input_shape[2] = input_height
            layer.input_shapes[i] = input_shape

    @staticmethod
    def _modify_input_shape(layer, input_width, input_height):
        # TODO: dense to conv
        input_shape = layer.input_shape
        input_shape[1] = input_width
        input_shape[2] = input_height
        layer.input_shape = input_shape

    def _modify_dense_layer(self, dense_layer, input_width, input_height):
        """
        Replace a dense layer in the model with a 1x1 ``FusedConv2DLayer`` and move its weights.

        The new layer takes the dense layer's name and activation, uses unit strides and dilations
        with ``valid`` padding, and gets an input shape of
        ``[-1, input_height, input_width, features]``. The dense kernel is stored for the new layer
        reshaped with two leading unit dimensions; the bias is stored unchanged.

        Args:
            dense_layer: the dense layer to replace.
            input_width: input width of the replacing layer.
            input_height: input height of the replacing layer.
        """
        in_features, out_features = dense_layer.kernel_shape[0], dense_layer.kernel_shape[1]

        replacement = FusedConv2DLayer()
        replacement.name = dense_layer.name
        replacement.input_shape = [-1, input_height, input_width, in_features]
        replacement.kernel_shape = [1, 1, in_features, out_features]
        replacement.strides = [1, 1, 1, 1]
        replacement.dilations = [1, 1, 1, 1]
        replacement.padding = PaddingType.valid
        # Output shapes are derived once the geometry above is in place.
        replacement.update_output_shapes()
        replacement.activation = dense_layer.activation
        self._model.replace_layer(dense_layer, replacement)

        dense_kernel = self._params[f"{dense_layer.name}/kernel:0"]
        dense_bias = self._params[f"{dense_layer.name}/bias:0"]
        conv_kernel_shape = (1, 1) + tuple(dense_kernel.shape)
        self._params[f"{replacement.name}/kernel:0"] = np.reshape(dense_kernel, conv_kernel_shape)
        self._params[f"{replacement.name}/bias:0"] = dense_bias

    def _modify_resize_layer(self, layer, input_width, input_height):
        # TODO: update ratios / output shapes / kernel size - based on layers changes.
        layer.w_ratios = [layer.output_width / input_width]
        layer.h_ratios = [layer.output_height / input_height]
        self._model.update_input_shapes_from_predecessors(layer)
