#!/usr/bin/env python

from hailo_model_optimization.acceleras.utils.acceleras_definitions import SplitFusedActivationPolicy
from hailo_sdk_client.post_fuser.algorithms.fuser_algorithm import FuserAlgorithm
from hailo_sdk_client.tools.fuser.fuser_helper import FuserHelper
from hailo_sdk_common.hailo_nn.hn_definitions import ActivationType
from hailo_sdk_common.hailo_nn.hn_layers import FusedStandaloneActivationLayer, LayerWithActivation
from hailo_sdk_common.hailo_nn.hn_layers.activation_layer import ActivationLayer


class SplitFusedActivation(FuserAlgorithm):
    """Post-fuser algorithm that pulls fused activations out into standalone layers.

    For every layer whose ``split_fused_activation`` config entry has policy
    ``SplitFusedActivationPolicy.enabled`` and which qualifies, the activation is
    handed to a new ``FusedStandaloneActivationLayer`` pushed after it. A layer
    qualifies when it is a ``LayerWithActivation``, ``is_activation_fusible`` is
    set, and its activation is not ``ActivationType.linear``. The original layer
    is reset to a linear activation. Any activation parameters registered under
    the original layer's name are re-keyed to the new layer.
    """

    NAME = "split_fused_activation"

    def __init__(self, model, params, model_config, hw_arch, **kwargs):
        # Extra keyword arguments are accepted but not forwarded to the base class.
        super().__init__(model, params, model_config, hw_arch)
        self._fuser_helper = FuserHelper(self.model)

    def get_algo_config(self):
        """Return the ``split_fused_activation`` section of the model config."""
        return self._model_config.split_fused_activation

    def _setup(self):
        # No preparation is needed before running.
        pass

    def _run_int(self):
        """Split the fused activation out of every qualifying, config-enabled layer."""
        enabled = SplitFusedActivationPolicy.enabled

        # Resolve all target layers first, before the model graph is mutated below.
        candidates = []
        for name, cfg in self.get_algo_config().layers.items():
            if cfg.policy == enabled:
                candidates.append(self.model.get_layer_by_name(name))

        for layer in candidates:
            if not isinstance(layer, LayerWithActivation) or not layer.is_activation_fusible:
                continue
            if layer.activation != ActivationType.linear:
                self._split_out_activation(layer)

    def _split_out_activation(self, layer):
        """Move ``layer``'s activation into a new standalone layer pushed after it.

        Args:
            layer: a ``LayerWithActivation`` whose activation is fusible and non-linear.
        """
        standalone = FusedStandaloneActivationLayer()
        standalone.index = self.model.get_next_index()
        standalone.name = f"{layer.name}_defused_activation"
        standalone.original_names = layer.original_names.copy()
        standalone.input_shapes = [layer.output_shapes[0]]
        standalone.output_shapes = layer.output_shapes.copy()
        # Hand the activation over to the new layer and leave the source layer linear.
        standalone.activation = layer.activation
        layer.activation = ActivationType.linear

        # calc_shapes=False: presumably because the shapes were assigned explicitly above.
        self.model.push_layer(standalone, [layer], calc_shapes=False)
        # The new layer inherits the source layer's output indices, and the source
        # layer is pointed at the new layer only.
        # NOTE(review): layer.output_shapes is left as-is even though the layer now has a
        # single output index -- confirm this is intended for multi-output layers.
        standalone.output_indices = layer.output_indices.copy()
        layer.output_indices = [standalone.index]

        self._rekey_activation_params(layer.name, standalone)

    def _rekey_activation_params(self, src_name, standalone):
        """Move the activation's parameters from ``src_name`` to ``standalone``'s name.

        Activations with no entry in ``ActivationLayer.ACTIVATION_TO_PARAPMS`` (or an
        empty one) have nothing to move.
        """
        param_names = ActivationLayer.ACTIVATION_TO_PARAPMS.get(standalone.activation)
        for param in param_names or ():
            src_key = f"{src_name}/{param}:0"
            dst_key = f"{standalone.name}/{param}:0"
            self.params.update({dst_key: self.params[src_key]})
            self.params.remove(src_key)

    def should_skip_algo(self):
        """This algorithm always runs; per-layer filtering happens in ``_run_int``."""
        return False

    def log_config(self):
        # Nothing is logged for this algorithm.
        pass

    @staticmethod
    def _should_keep_param(x):
        # No parameter is kept by this algorithm.
        return False
