from hailo_sdk_client.post_fuser.algorithms.fuser_algorithm import FuserAlgorithm
from hailo_sdk_client.tools.fuser.fuser_helper import FuserHelper
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType, LayerType
from hailo_sdk_common.hailo_nn.hn_layers import LayerWithActivation


class ActivationsFolding(FuserAlgorithm):
    """
    Fold standalone ``activation`` layers into their predecessor layer.

    For every layer whose op is ``LayerType.activation``, if its single predecessor can absorb an
    activation (see ``_can_fuse_post_layer_standalone_activation``), the activation is written onto the
    predecessor's ``activation`` field and the standalone layer is removed from the model.
    """

    NAME = "activations_folding"

    def __init__(self, model, params, model_config, hw_arch, **kwargs):
        super().__init__(model, params, model_config, hw_arch)
        # NOTE(review): extra ``kwargs`` are accepted but not forwarded to the base class - confirm intended.
        self._fuser_helper = FuserHelper(model)

    def get_algo_config(self):
        return self._model_config

    def _setup(self):
        # No setup is required for this algorithm.
        pass

    def _run_int(self):
        self._fold_post_layer_standalone_activation_layers()

    def should_skip_algo(self):
        # Always False: this algorithm is never skipped.
        return False

    def log_config(self):
        # No algorithm-specific configuration to log.
        pass

    def _fold_post_layer_standalone_activation_layers(self):
        """
        Fold standalone activation layers onto predecessors that have an "Activation" field.

        For each foldable activation layer:
          1. copy its ``activation`` onto the predecessor;
          2. if it carries native weights, re-key them under the predecessor's name and drop its own entry;
          3. schedule it for removal via ``FuserHelper.remove_layer``;
          4. if it was listed in ``output_layers_order``, replace it there with the predecessor.

        Model removal is deferred until after the loop so the graph is not mutated while deciding.
        """
        layers_to_remove = []
        # Iterate over a snapshot: FuserHelper.remove_layer may touch the graph during the loop.
        for layer in list(self._model):
            if layer.op != LayerType.activation:
                continue

            preds = list(self._model.predecessors(layer))
            if not self._can_fuse_post_layer_standalone_activation(preds):
                continue

            pred = preds[0]
            pred.activation = layer.activation

            if layer.requires_native_weights:
                # Move the activation's parameters under the predecessor's name so they survive removal.
                new_params = {f"{pred.name}/{k}": v for k, v in self._params[layer.name].items()}
                self._params.update(new_params)
                self._params.remove(layer.name)

            # Presumably reconnects edges and appends ``layer`` to ``layers_to_remove`` - verify in FuserHelper.
            self._fuser_helper.remove_layer(layer, layers_to_remove)

            # Keep the network output ordering valid: the predecessor now produces this output.
            output_order = self._model.net_params.output_layers_order
            if layer.name in output_order:
                output_order[output_order.index(layer.name)] = pred.name

            self._logger.debug(f"Folded {layer.op.value} layer {layer.name} onto predecessor layer {pred.name}")

        for layer_to_remove in layers_to_remove:
            self._model.remove_layer(layer_to_remove)

    def _can_fuse_post_layer_standalone_activation(self, preds):
        """
        Return True if a standalone activation with predecessors ``preds`` can be folded.

        Requirements: exactly one predecessor, which has no other successors, is a fusible
        ``LayerWithActivation``, and currently has a linear activation.
        """
        # Exactly one predecessor is required; this also guards ``preds[0]`` against an empty list.
        if len(preds) != 1:
            return False

        pred = preds[0]

        # can't fold activation over a layer that has more than one output
        if len(list(self._model.successors(pred))) > 1:
            return False

        if not (isinstance(pred, LayerWithActivation) and pred.is_activation_fusible):
            return False

        # can't fold activation over a layer that already has a (non-linear) activation
        return pred.activation == ActivationType.linear
