import numpy as np

from hailo_sdk_client.post_fuser.algorithms.exceptions import SoftmaxMappingException
from hailo_sdk_client.post_fuser.algorithms.fuser_algorithm import FuserAlgorithm
from hailo_sdk_client.tools.fuser.fuser_helper import FuserHelper, PostprocessAdditionMode
from hailo_sdk_client.tools.logits_layer_addition import LogitsLayersAdder
from hailo_sdk_common.hailo_nn.hailo_nn import hn_to_npz_key
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, LayerType
from hailo_sdk_common.hailo_nn.hn_layers import (
    EWMultLayer,
    FusedStandaloneEWSubLayer,
    ReduceMaxLayer,
    ReduceSumLayer,
    ShortcutLayer,
)
from hailo_sdk_common.hailo_nn.hn_layers.const_input import ConstInputLayer


class SoftmaxMapping(FuserAlgorithm):
    """Decide where each softmax layer of the model runs and rewrite the graph accordingly.

    At most ``LogitsLayersAdder.PPU_MAX_SUPPORTED_SOFTMAX_LAYERS`` softmax layers that satisfy
    :meth:`validate_ppu_softmax_conditions` are left untouched (PPU). Every other softmax layer is
    either split into an equivalent block of element-wise / reduce layers, replaced by a host
    post-process logits layer, or - when it directly follows a dense layer - converted to a rank 4
    layer and then split.
    """

    NAME = "softmax_mapping"

    def __init__(self, model, params, model_config, hw_arch, **kwargs):
        """Initialise the algorithm and its graph helper.

        Args:
            model: the model graph to operate on.
            params: the model parameters (updated when an additive mask is converted).
            model_config: returned as-is by :meth:`get_algo_config`.
            hw_arch: target hardware architecture; its ``name`` is read by :meth:`should_skip_algo`.
            **kwargs: accepted for signature compatibility and ignored.
        """
        super().__init__(model, params, model_config, hw_arch)
        self._fuser_helper = FuserHelper(self._model)

    def get_algo_config(self):
        """Return the model configuration this algorithm was created with."""
        return self._model_config

    def _run_int(self):
        """Keep the largest eligible softmax layers on the PPU and remap all the others.

        Raises:
            SoftmaxMappingException: if a softmax layer can neither be split, nor moved to the host,
                nor converted to rank 4 (its predecessor is not a dense layer).
        """
        layers = [softmax_layer for softmax_layer in self.model.nodes if softmax_layer.op == LayerType.softmax]
        max_ppu_layers = LogitsLayersAdder.PPU_MAX_SUPPORTED_SOFTMAX_LAYERS
        if len(layers) <= max_ppu_layers and self.validate_ppu_softmax_conditions(layers):
            # no need to remap softmax layers, all layers meet the requirements of the ppu
            return

        # remapping softmax layers (to PPU, LCU and HRT-pp)
        ppu_softmax_layers = [layer for layer in layers if self.validate_ppu_softmax_conditions([layer])]
        # sorts layers by vector size, the largest ones will be mapped to the PPU
        ppu_softmax_layers = sorted(ppu_softmax_layers, key=lambda softmax: softmax.input_shape[-1], reverse=True)
        # takes the first supported softmax layers and maps them to the ppu
        ppu_softmax_layers = ppu_softmax_layers[:max_ppu_layers]
        # Keep graph order: iterating a set difference would make the rewrite order (and therefore
        # the indices allocated for the new layers) depend on hash order.
        layers_to_remap = [layer for layer in layers if layer not in ppu_softmax_layers]
        for layer_to_remap in layers_to_remap:
            if self._should_split_softmax(layer_to_remap):
                # the layer meets the limitation of lcu. remaps layer to lcu
                self._split_and_broadcast_layers(layer_to_remap)
            elif layer_to_remap in self.model.get_real_output_layers():
                self._map_layers_to_host(layer_to_remap)
            else:
                # trying to change the softmax's rank from 2 to 4
                pred = next(iter(self._model.predecessors(layer_to_remap)))
                if pred.op == LayerType.dense:
                    self._logger.info(
                        f"Due to availability of resources, softmax layer {layer_to_remap.name} "
                        f" with output rank 2 was replaced with a block of layers "
                        f"(equivalent implementation) mapped to the neural core, with rank 4 output shape instead",
                    )
                    conv1x1_layer = FuserHelper.replace_dense_with_conv1x1(pred, self._model, params=self.params)
                    # updates softmax layer's shapes to the matching output of the new conv layer
                    output_index = pred.outputs.index(layer_to_remap.name)
                    layer_to_remap.input_shapes = [conv1x1_layer.output_shapes[output_index].copy()]
                    layer_to_remap.output_shapes = [conv1x1_layer.output_shapes[output_index].copy()]

                    self._split_and_broadcast_layers(layer_to_remap, ignore_input_width_condition=True)
                else:
                    raise SoftmaxMappingException(f"Can't map softmax layer {layer_to_remap.name}")

    def _should_split_softmax(self, layer, ignore_input_width_condition=False):
        """Tell whether ``layer`` is a softmax that should be split into a block of layers.

        A softmax layer qualifies when its input is not rank 2 and its input width differs from 1
        (the width test is waived when ``ignore_input_width_condition`` is set), or when it has more
        than one group.
        """
        return layer.op == LayerType.softmax and (
            (len(layer.input_shape) != 2 and (layer.input_width != 1 or ignore_input_width_condition))
            or layer.groups > 1
        )

    def _split_and_broadcast_layers(self, layer_to_remap, ignore_input_width_condition=False):
        """Split ``layer_to_remap`` and run the element-wise broadcast pass on the new layers."""
        layers = self.split_softmax_layers(layer_to_remap, ignore_input_width_condition)
        self._fuser_helper.run_broadcast_ew(layers)
        self._fuser_helper.replace_spatial_input_repeats_with_resize()  # TODO: remove after SDK-55542 is done

    def _map_layers_to_host(self, layer_to_remap):
        """Replace an output softmax layer with a post-process logits layer (HRT-pp)."""
        # maps layer to HRT-pp
        pred_layer = next(iter(self._model.predecessors(layer_to_remap)))  # softmax has only one predecessor
        new_logits_layer_name = FuserHelper.add_logits_as_postprocess_layer_to_hn(
            self.model,
            [pred_layer],
            LayerType.softmax,
            layer_to_remap.axis,
            postprocess_addition_mode=PostprocessAdditionMode.MAPPING_FROM_NN_CORE_TO_CPU,
        )

        # removes softmax and softmax output layer from the graph, keeping the output order slot
        output_layer = next(iter(self._model.successors(layer_to_remap)))
        output_index = self._model.net_params.output_layers_order.index(layer_to_remap.name)
        self._model.net_params.output_layers_order[output_index] = new_logits_layer_name[0]
        self._model.remove_layer(output_layer)
        self._model.remove_layer(layer_to_remap)
        self._logger.info(
            f"Mapping the softmax layer {layer_to_remap.name_without_scope} from the neural "
            f"core to CPU due to availability of resources",
        )

    def validate_ppu_softmax_conditions(self, layers):
        """Return True only if every layer in ``layers`` satisfies the PPU softmax conditions.

        A layer passes when exactly one of its successors is an output layer (or an argmax that is a
        real output layer of the model) and its output is either of rank
        ``LogitsLayersAdder.SOFTMAX_SUPPORTED_RANK`` or of rank 4 with width 1.
        An empty ``layers`` sequence yields True.
        """
        for layer in layers:
            successors = list(self.model.successors(layer))
            output_layer = [
                successor
                for successor in successors
                if (
                    (successor.op == LayerType.output_layer)
                    or (successor.op == LayerType.argmax and successor in self.model.get_real_output_layers())
                )
            ]
            if len(output_layer) != 1:
                return False
            if not (
                (len(layer.output_shape) == LogitsLayersAdder.SOFTMAX_SUPPORTED_RANK)
                or (len(layer.output_shape) == 4 and layer.output_width == 1)
            ):
                return False

        return True

    def split_softmax_layers(self, softmax, ignore_input_width_condition=False):
        """Split a softmax layer that can't be mapped to PPU into an equivalent block of layers.

        Without an additive mask the block is wired as::

            preds -> reduce_max -> ew_sub (exp) -> reduce_sum (inv_pos) -> ew_mult -> succs
            preds ---------------> ew_sub ---------------------------------> ew_mult

        With an additive mask, a const input holding a 0/1 mask feeds ``ew_mult2`` (placed before
        reduce_max / ew_sub) and ``ew_mult3`` (placed after ew_sub, before reduce_sum / ew_mult).

        Args:
            softmax: the softmax layer to replace; it is removed from the model when split.
            ignore_input_width_condition: forwarded to :meth:`_should_split_softmax`.

        Returns:
            list: the newly created layers, each listed once; empty when the layer is not split.
        """
        new_layers = []
        if not self._should_split_softmax(softmax, ignore_input_width_condition):
            return new_layers

        scope = f"{softmax.scope}/"
        pred = next(iter(self._model.predecessors(softmax)))
        shape = softmax.input_shape.copy()

        if pred.op == LayerType.feature_splitter:
            # In order to add multiple output copies (reduce_max and ew_sub) from features splitter output a
            # shortcut layer must be added
            shortcut = self._fuser_helper.create_layer(
                ShortcutLayer,
                self._model.get_next_index(),
                "shortcut",
                softmax,
                new_layers,
                [softmax.output_shape.copy()],
            )
            self.model.insert_layers({shortcut: softmax})
            shortcut.input_shapes = [shape.copy()]
            shortcut.output_shapes = [shape.copy()]

        # Normalise a negative axis; axis 0 is already non-negative and must be kept as is.
        axis = softmax.axis if softmax.axis >= 0 else softmax.axis + len(shape)
        reduced_shape = [dim if i != axis else softmax.groups for i, dim in enumerate(shape)]
        block_name, softmax_name = self.get_block_and_layer_names(softmax.name_without_scope)

        # final multiplication: exp values times the output of reduce_sum
        ew_mult = EWMultLayer()
        ew_mult.index = self._model.get_next_index()
        ew_mult.name = f"{scope}{block_name}ew_mult_{softmax_name}"
        ew_mult.input_shapes = [reduced_shape, shape]  # precision fix
        ew_mult.output_shapes = softmax.output_shapes
        new_layers.append(ew_mult)
        ew_mult.move_params(softmax)

        reduce_max = ReduceMaxLayer()
        reduce_max.index = ew_mult.index + 1
        reduce_max.name = f"{scope}{block_name}reduce_max_{softmax_name}"
        new_layers.append(reduce_max)
        reduce_max.input_shape = shape
        reduce_max.output_shapes = [reduced_shape]
        reduce_max.groups = softmax.groups
        reduce_max.reduce_axes = [axis]
        reduce_max.move_params(softmax)

        # subtracts the max from the input and applies the exp activation
        ew_sub = FusedStandaloneEWSubLayer()
        ew_sub.index = reduce_max.index + 1
        ew_sub.name = f"{scope}{block_name}ew_sub_{softmax_name}"
        self._model.add_node(ew_sub)
        new_layers.append(ew_sub)
        ew_sub.input_shapes = [shape, reduced_shape]
        ew_sub.output_shapes = [shape]
        ew_sub.activation = ActivationType.exp
        ew_sub.move_params(softmax)

        # adds additive mask if needed
        additive_mask_key = hn_to_npz_key(softmax.name, "additive_mask")
        additive_mask = self.params.get(additive_mask_key).copy() if additive_mask_key in self.params else None
        has_additive_mask = additive_mask is not None
        if has_additive_mask:
            # adds ew_mult to drop values
            # Convert the additive mask to a multiplicative one: zeros become ones and every other
            # value (including a literal 1) becomes zero.
            additive_mask = (additive_mask == 0).astype(additive_mask.dtype)
            const_input = ConstInputLayer()
            const_input.name = f"{scope}{block_name}const_input_{softmax_name}"
            const_input.index = ew_sub.index + 1
            const_input.original_names = softmax.original_names.copy()
            const_input.input_tiles = [[1, 1, shape[-1] // additive_mask.shape[-1]]]
            const_input.input_shapes = [[-1, *additive_mask.shape]]
            const_input.output_shapes = 2 * [shape]
            const_input.const_values = additive_mask

            new_param = {hn_to_npz_key(const_input.name, "const_data"): additive_mask}
            self.params.remove(softmax.name)
            self.params.update(new_param)

            # mul before reduce_max
            ew_mult2 = EWMultLayer()
            ew_mult2.index = const_input.index + 1
            ew_mult2.name = f"{scope}{block_name}ew_mult2_{softmax_name}"
            ew_mult2.input_shapes = 2 * [shape]
            ew_mult2.output_shapes = 2 * [shape]
            ew_mult2.is_softmax_mask = True
            ew_mult2.move_params(softmax)
            new_layers += [ew_mult2, const_input]

            # mul after exp; takes the index following ew_mult2 so the two layers do not collide
            ew_mult3 = EWMultLayer()
            ew_mult3.index = ew_mult2.index + 1
            ew_mult3.name = f"{scope}{block_name}ew_mult3_{softmax_name}"
            ew_mult3.input_shapes = 2 * [shape]
            ew_mult3.output_shapes = 2 * [shape]
            ew_mult3.move_params(softmax)
            new_layers.append(ew_mult3)
            reduce_sum_pred = ew_mult3
        else:
            reduce_sum_pred = ew_sub

        reduce_sum = ReduceSumLayer()
        reduce_sum.index = reduce_sum_pred.index + 1
        reduce_sum.name = f"{scope}{block_name}reduce_sum_{softmax_name}"
        new_layers.append(reduce_sum)
        reduce_sum.input_shape = shape
        reduce_sum.output_shapes = [reduced_shape]
        reduce_sum.reduce_axes = [axis]
        reduce_sum.groups = softmax.groups
        reduce_sum.activation = ActivationType.inv_pos
        reduce_sum.move_params(softmax)

        # NOTE(review): axis was normalised above, so for a valid axis only the value 3 can match.
        if axis in [-1, 3]:
            self._fuser_helper.update_input_repeats(ew_sub, reduced_shape, shape, axis)
            self._fuser_helper.update_input_repeats(ew_mult, reduced_shape, shape, axis)

        # connects layers to the graph
        succs = list(self.model.successors(softmax))
        preds = list(self.model.predecessors(softmax))
        block_entry = ew_mult2 if has_additive_mask else reduce_max
        for softmax_pred in preds:
            self._fuser_helper.replace_succ(softmax_pred, softmax, block_entry)
            if not has_additive_mask:
                self._fuser_helper.add_succs(softmax_pred, [ew_sub])

        for succ in succs:
            self._fuser_helper.replace_pred(succ, softmax, ew_mult)

        preds_succs_mapping = {}
        preds_succs_mapping.update({ew_mult: {"succs": succs}})
        preds_succs_mapping.update({reduce_max: {"succs": [ew_sub]}})
        preds_succs_mapping.update({reduce_sum: {"succs": [ew_mult]}})

        if has_additive_mask:
            # there is additive mask
            preds_succs_mapping[reduce_max].update({"preds": [ew_mult2]})
            preds_succs_mapping[reduce_sum].update({"preds": [ew_mult3]})
            preds_succs_mapping[ew_mult].update({"preds": [reduce_sum, ew_mult3]})  # precision fix
            preds_succs_mapping.update({ew_sub: {"succs": [ew_mult3], "preds": [ew_mult2, reduce_max]}})

            preds_succs_mapping.update({ew_mult2: {"preds": [*preds, const_input], "succs": [reduce_max, ew_sub]}})
            preds_succs_mapping.update({ew_mult3: {"preds": [ew_sub, const_input], "succs": [reduce_sum, ew_mult]}})
            preds_succs_mapping.update({const_input: {"succs": [ew_mult2, ew_mult3], "preds": []}})
        else:
            preds_succs_mapping[reduce_max].update({"preds": preds})
            preds_succs_mapping[reduce_sum].update({"preds": [ew_sub]})
            preds_succs_mapping[ew_mult].update({"preds": [reduce_sum, ew_sub]})  # precision fix
            preds_succs_mapping.update({ew_sub: {"succs": [ew_mult, reduce_sum], "preds": [*preds, reduce_max]}})

        # shapes were set explicitly above, so they must not be recomputed while wiring the layers
        update_shapes = {"update_input_shapes": False, "update_output_shapes": False}
        for layer, mapping in preds_succs_mapping.items():
            self._fuser_helper.add_succs_and_preds(layer, **mapping, **update_shapes)

        # the last layer of the block takes over the softmax's slot in the outputs order
        if softmax.name in self._model.net_params.output_layers_order:
            output_index = self._model.net_params.output_layers_order.index(softmax.name)
            self._model.net_params.output_layers_order[output_index] = ew_mult.name

        self._logger.debug(f"Replaced softmax layer {softmax.name} with layers that run on LCUs")
        self._model.remove_layer(softmax)

        return new_layers

    def _setup(self):
        """No setup is performed by this algorithm."""

    def log_config(self):
        """This algorithm logs no configuration."""

    def should_skip_algo(self) -> bool:
        """
        Here we decide whether to skip the algorithm based on the hardware architecture:
        it is skipped when the architecture name is "hailo10h2".
        """

        return self._hw_arch.name == "hailo10h2"
