from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.hailo_layers.hailo_standalone_activation import HailoStandaloneActivation
from hailo_model_optimization.acceleras.model_optimization_config.mo_config_layer import LayerTranslationConfig
from hailo_model_optimization.acceleras.utils.acceleras_definitions import BiasMode
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class DecomposeChannelWiseQuantization(OptimizationAlgorithm):
    """Split channel-wise quantized layers whose activation needs too many slopes.

    Such a layer is turned into the same layer with a linear activation, followed by a
    standalone activation layer that carries the original activation and its weights.

    Given that we have 9 slopes per APU table, we can switch the table every 4 features.
    A layer is decomposed when
    ``base_group_size * (SLOPES_PER_TABLE // slopes_per_feature)`` is smaller than that
    switching granularity; for ``base_group_size == 1`` this means an activation that
    requires more than 2 slopes.
    """

    # Number of activation slopes available in a single APU table.
    SLOPES_PER_TABLE = 9
    # Feature granularity at which the APU table can be switched.
    TABLE_SWITCH_FEATURES = 4

    def __init__(self, model, model_config, logger_level, **kwargs):
        super().__init__(model, model_config, "Decompose Channel-Wise Quantization", logger_level, **kwargs)
        # Names of layers selected by ``_setup`` for decomposition in ``_run_int``.
        self._layers_to_decompose = set()

    def get_algo_config(self):
        """Return the precision config section of the model config."""
        return self._model_config.precision_config

    def finalize_global_cfg(self, algo_config):
        """No global configuration to finalize for this algorithm."""

    def should_skip_algo(self) -> bool:
        """This algorithm is never skipped; ``_setup`` decides which layers it touches."""
        return False

    def _setup(self) -> None:
        """Collect the names of layers whose activation must be moved to a standalone layer.

        A layer qualifies when it has an activation atomic op, uses more than one
        quantization group, and the features covered per activation table are fewer than
        ``TABLE_SWITCH_FEATURES``.
        """
        for lname, layer in self._model.iterate_layers():
            # Short-circuit keeps get_quantization_groups() uncalled for layers without an activation op.
            if layer.activation_atomic_op is None or not layer.get_quantization_groups() > 1:
                continue
            if self._features_covered_per_table(layer) < self.TABLE_SWITCH_FEATURES:
                self._layers_to_decompose.add(lname)

    def _features_covered_per_table(self, layer: BaseHailoLayer) -> int:
        """Return ``base_group_size * (SLOPES_PER_TABLE // slopes_per_feature)`` for ``layer``'s activation."""
        # NOTE(review): assumes get_slopes_count() is never 0 (would raise ZeroDivisionError).
        slopes_per_feature = layer.activation_atomic_op.get_slopes_count()
        features_per_table = self.SLOPES_PER_TABLE // slopes_per_feature
        base_group_size = layer.activation_atomic_op.base_group_size
        return base_group_size * features_per_table

    def _run_int(self) -> None:
        """Decompose every selected layer, then re-import the updated model config."""
        # Iterate in sorted order: iteration order of a set of strings depends on the
        # per-process hash seed, which would make the order in which new layers are added
        # to the model differ between runs.
        for lname in sorted(self._layers_to_decompose):
            self._decompose_layer(self._model.layers[lname])
        if self._layers_to_decompose:
            # Presumably so the newly added activation layers receive their config entries.
            self._model.import_config(self._model_config)

    def _decompose_layer(self, layer: BaseHailoLayer) -> None:
        """Move ``layer``'s activation into a new ``HailoStandaloneActivation`` layer.

        ``layer`` is switched to a linear activation. The new layer receives the exported
        activation weights and is added to the model with edges from ``layer`` to each of
        its successors (the exact rewiring is done by ``self._model.add_layer``). Precision
        and translation configs for the new layer are copied from ``layer``'s entries.
        """
        act_op_weights = layer.activation_atomic_op.export_weights()

        scope, layer_name = layer.full_name.split("/")
        block_name, layer_name = self.get_block_and_layer_names(layer_name)
        act_layer_name = f"{scope}/{block_name}activation_{layer_name}"

        # Read act_name before the op is switched to linear below.
        act_layer = HailoStandaloneActivation(name=act_layer_name, activation=layer.activation_atomic_op.act_name)

        # Change layer activation to linear, both in the atomic op and in the HN params.
        layer.activation_atomic_op.create_act_name_and_func("linear")
        layer._hn_element["params"]["activation"] = "linear"

        act_layer.import_weights(act_op_weights)
        successors = self._model.flow.successors_sorted(layer.full_name)
        edges = [(layer.full_name, suc) for suc in successors]
        self._model.add_layer(act_layer, edges)

        self._copy_precision_config(layer.full_name, act_layer, act_layer_name)
        self._copy_translation_config(layer.full_name, act_layer_name)

    def _copy_precision_config(
        self, src_name: str, act_layer: HailoStandaloneActivation, act_layer_name: str
    ) -> None:
        """Store a copy of ``src_name``'s precision config (or the default) for the activation layer."""
        precision_config = self._model_config.precision_config.layers
        precision_config[act_layer_name] = precision_config.get(
            src_name, act_layer.get_default_precision_config()
        ).copy()
        act_cfg = precision_config[act_layer_name]
        # The standalone activation uses a single quantization group.
        act_cfg.quantization_groups = 1
        act_cfg.precision_mode = act_cfg.precision_mode.output_precision_mode()
        act_cfg.bias_mode = BiasMode.double_scale_initialization

    def _copy_translation_config(self, src_name: str, act_layer_name: str) -> None:
        """Store a copy of ``src_name``'s finalized translation config (or the default) for the activation layer."""
        finalized_translation_layers = self.finalize_layer_cfg(self._model_config.translation_config.layers)
        self._model_config.translation_config.layers[act_layer_name] = finalized_translation_layers.get(
            src_name, LayerTranslationConfig.get_default()
        ).copy()

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return ``cfg`` unchanged; no per-layer validation is needed here.

        NOTE(review): presumably a hook used by the base class (e.g. ``finalize_layer_cfg``).
        """
        return cfg
