from hailo_model_optimization.acceleras.hailo_layers.hailo_conv import HailoConv
from hailo_model_optimization.acceleras.hailo_layers.hailo_conv_add import HailoConvAdd
from hailo_model_optimization.acceleras.hailo_layers.hailo_dense import HailoDense
from hailo_model_optimization.acceleras.hailo_layers.hailo_standalone_activation import (
    HailoStandaloneActivation,
)
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import (
    LayerTranslationConfig,
)
from hailo_model_optimization.algorithms.optimization_algorithm import (
    OptimizationAlgorithm,
)


class AddShortcutLayer(OptimizationAlgorithm):
    """
    Insert a standalone activation layer between "layer" and "target".

    The new layer is added on the (layer, target) edge; its activation is
    linear by default.

    before : layer -> target
    after  : layer -> act -> target
    """

    SUPPORTED_AUTO4BIT = {HailoConv, HailoConvAdd, HailoDense}

    def __init__(
        self,
        model,
        model_config,
        logger_level: int,
        logger=None,
        for_infer=False,
    ):
        super().__init__(
            model=model,
            model_config=model_config,
            name="Add Shortcut Layer",
            logger_level=logger_level,
            logger=logger,
        )
        self._generated_statistics = dict()
        self._for_infer = for_infer
        self._shortcut_added_index = 0

    def _setup(self):
        pass

    def should_skip_algo(self) -> bool:
        """Return True when no shortcut layers are configured."""
        return len(self.get_algo_config().layers) == 0

    def finalize_global_cfg(self, algo_config):
        pass

    def get_algo_config(self):
        """Return the ``add_shortcut_layer`` section of the model config."""
        return self._model_config.add_shortcut_layer

    def log_config(self):
        pass

    def _run_int(self):
        """
        Insert one shortcut activation layer per configured (layer, target) pair.

        For every source layer in the algo config, ``target_conf.target`` may be
        a single layer name or a collection of names. When ``full_name`` is set
        and there are several targets, each inserted layer gets the suffix
        ``_<index>`` so that names stay unique. A single target keeps
        ``full_name`` unchanged.
        """
        cfg = self.get_algo_config()
        for layer, target_conf in cfg.layers.items():
            targets = target_conf.target
            activation = target_conf.activation
            # Normalise to a list so a lone string is not iterated char-by-char
            # and so the number of targets is known.
            targets = [targets] if isinstance(targets, str) else list(targets)
            base_name = target_conf.full_name
            multiple_targets = len(targets) > 1
            for ind, target in enumerate(targets):
                name = base_name
                # Layer names must be unique (they key the model and the
                # translation config), so disambiguate a shared user-supplied
                # name when it applies to several targets.
                if name is not None and multiple_targets:
                    name = f"{name}_{ind}"
                self._add_shortcut_layer(
                    source=layer,
                    target=target,
                    name=name,
                    activation=activation,
                )

    def _get_valid_layer_cfg(self, lname, cfg):
        return cfg

    def _add_shortcut_layer(self, source, target, name=None, activation="linear"):
        """
        Create a standalone activation layer on the edge ``source -> target``.

        Args:
            source: name of the producing layer; its first output shape
                (with the batch dim set to -1) is used as both the input and
                output shape of the new layer.
            target: name of the consuming layer.
            name: name of the new layer. Defaults to
                ``"<net>/<source_part>_<target_part>_shortcut"``, built from the
                first two "/"-separated components of ``source`` and the second
                component of ``target``.
            activation: activation type written into the layer's HN params.
        """
        shape = list(self._model.layers[source].output_shapes[0])
        # Batch dimension is left dynamic.
        shape[0] = -1
        hn = {
            "type": "activation",
            "input": source,
            "output": target,
            "input_shapes": [shape],
            "output_shapes": [shape],
            "params": {"activation": activation},
        }
        if name is None:
            # NOTE(review): assumes both names contain at least one "/";
            # otherwise indexing [1] raises IndexError.
            net = source.split("/")[0]
            st = source.split("/")[1]
            end = target.split("/")[1]
            name = f"{net}/{st}_{end}_shortcut"

        shortcut_layer = HailoStandaloneActivation.from_hn(
            lname=name,
            hn_element=hn,
        )

        # presumably add_layer splices the new layer into this edge — confirm
        # against the model's add_layer implementation.
        edges = [(source, target)]
        self._model.add_layer(shortcut_layer, edges)
        # Register a default translation config so the new layer is quantizable.
        self._model_config.translation_config.layers[name] = LayerTranslationConfig.get_default()
