from hailo_sdk_client.model_translator.graph_lookup import (
    BwdChainNode,
    FwdChainNode,
    get_node_from_possible_chains,
    look_for_node,
)
from hailo_sdk_client.post_fuser.algorithms.fuser_algorithm import FuserAlgorithm
from hailo_sdk_client.tools.fuser.fuser_helper import FuserHelper
from hailo_sdk_common.hailo_nn.hailo_nn import hn_to_npz_key
from hailo_sdk_common.hailo_nn.hn_definitions import LayerType
from hailo_sdk_common.hailo_nn.hn_layers.const_input import ConstInputLayer


class FuseSoftmaxAdditiveMask(FuserAlgorithm):
    """Fuser pass that attaches additive masks directly to softmax layers.

    Two patterns are handled while walking the model:

    * An ``ew_add`` layer for which a ``softmax`` is found one forward chain
      step away, and which has an ``input_layer`` or ``const_input`` one
      backward chain step away, is removed; all of its predecessors are
      rewired onto the softmax instead.
    * A ``softmax`` layer that has an ``additive_mask`` entry in the params
      gets a new ``ConstInputLayer`` predecessor holding that mask.

    The pass is skipped unless the hardware architecture name is
    ``"hailo10h2"``.
    """

    NAME = "fuse_softmax_additive_mask"

    def __init__(self, model, params, model_config, hw_arch, **kwargs):
        """Initialise the algorithm and its fuser helper.

        Extra keyword arguments are accepted but are not forwarded to the
        base class.
        """
        super().__init__(model, params, model_config, hw_arch)
        self._fuser_helper = FuserHelper(self._model)

    def _run_int(self):
        """Run the fusion over the whole model."""
        self._fuse_softmax_additive_mask()

    def _fuse_softmax_additive_mask(self):
        """Apply both mask-fusion patterns to every layer of the model.

        ``ew_add`` layers that were fused are only removed from the model
        after the walk has finished.
        """
        layers_to_remove = []
        # Walk a snapshot of the layers; the helpers edit the graph in place.
        for layer in list(self._model):
            if layer.op == LayerType.ew_add:
                if self._fuse_ew_add_into_softmax(layer):
                    layers_to_remove.append(layer)
            elif layer.op == LayerType.softmax:
                # Fetch (and copy) the mask once; None means nothing to do.
                additive_mask = self._get_additive_mask(layer)
                if additive_mask is not None:
                    self._attach_mask_const_input(layer, additive_mask)

        for layer_to_remove in layers_to_remove:
            self._model.remove_layer(layer_to_remove)

    def _get_additive_mask(self, layer):
        """Return a copy of the layer's ``additive_mask`` param, or ``None`` if absent."""
        additive_mask_key = hn_to_npz_key(layer.name, "additive_mask")
        if additive_mask_key not in self.params:
            return None
        return self.params.get(additive_mask_key).copy()

    def _fuse_ew_add_into_softmax(self, ew_add):
        """Rewire the predecessors of ``ew_add`` onto its softmax successor.

        Returns:
            ``True`` when the edges were rewired and ``ew_add`` should be
            removed from the model by the caller, ``False`` when the pattern
            did not match and nothing was changed.
        """
        softmax_succ = look_for_node(
            self.model,
            ew_add,
            [FwdChainNode(op=LayerType.softmax)],
            exact_match=True,
        )
        if not softmax_succ:
            return False

        # The add only counts as an additive mask when one of its inputs is a
        # model input or a constant input.
        input_pred = get_node_from_possible_chains(
            self.model,
            ew_add,
            [[BwdChainNode(op=LayerType.input_layer)], [BwdChainNode(op=LayerType.const_input)]],
            exact_match=True,
        )
        if not input_pred:
            return False

        preds = list(self.model.predecessors(ew_add))
        for i, pred in enumerate(preds):
            self._fuser_helper.replace_succ(pred, ew_add, softmax_succ)
            if i == 0:
                # The first predecessor takes over the slot ew_add occupied.
                self._fuser_helper.replace_pred(softmax_succ, ew_add, pred)
            else:
                # Remaining predecessors become additional softmax inputs.
                self._fuser_helper.add_preds(softmax_succ, [pred])
        return True

    def _attach_mask_const_input(self, softmax, additive_mask):
        """Create a const-input layer carrying ``additive_mask`` and feed it to ``softmax``.

        The params registered under the softmax's name are removed and the
        mask is stored again as the new layer's ``const_data`` param.
        """
        const_input = ConstInputLayer()
        const_input.name = f"{softmax.name}_const_input"
        const_input.index = self.model.get_next_index()
        const_input.original_names = softmax.original_names.copy()
        # Tile factor on the last axis: softmax input size over mask size.
        const_input.input_tiles = [[1, 1, softmax.input_shape[-1] // additive_mask.shape[-1]]]
        const_input.input_shapes = [[-1, *additive_mask.shape]]
        const_input.output_shapes = [softmax.input_shape.copy()]
        const_input.const_values = additive_mask

        new_param = {hn_to_npz_key(const_input.name, "const_data"): additive_mask}
        # Order matters: drop the softmax's params before adding the new one.
        self.params.remove(softmax.name)
        self.params.update(new_param)

        self._fuser_helper.add_succs(const_input, [softmax])
        self._fuser_helper.add_preds(softmax, [const_input])

    def get_algo_config(self):
        """Return the model configuration this algorithm was created with."""
        return self._model_config

    def export_statistics(self):
        """No statistics are exported by this algorithm."""
        pass

    def _setup(self):
        """No setup is required by this algorithm."""
        pass

    def log_config(self):
        """Nothing is logged for this algorithm's configuration."""
        pass

    def should_skip_algo(self) -> bool:
        """
        Skip the algorithm unless the hardware architecture name is "hailo10h2".
        """
        return self._hw_arch.name != "hailo10h2"
