from hailo_model_optimization.acceleras.model.hailo_model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.utils.acceleras_definitions import LayerFeaturePolicy
from hailo_model_optimization.algorithms.mixed_precision.create_mixed_precision import CreateMixedPrecision
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class UsePreQuantWeights(OptimizationAlgorithm):
    """
    Optimization algorithm that lets the user supply weights that were quantized ahead of time by a
    third-party tool.

    The quantized weights need to be stored as the new native weights of the model (quantized * scale).
    For every enabled layer in the ``use_prequantized_weights`` configuration the algorithm sets the
    precision mode, the quantization groups and the scale calculation mode, and flags the layer's conv op
    to set its scale by kernel only. If at least one layer was configured, ``CreateMixedPrecision`` is run
    afterwards.
    """

    # Supported weight bit-widths and the precision mode each of them selects.
    _PRECISION_MODE_BY_BITS = {4: "a8_w4", 8: "a8_w8"}

    def __init__(self, model: HailoModel, model_config, logger_level, **kwargs):
        super().__init__(model, model_config, "Use Pre-Quantized Weights", logger_level, **kwargs)

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return the layer configuration unchanged; no per-layer adjustment is done here."""
        return cfg

    def get_algo_config(self):
        """Return the ``use_prequantized_weights`` section of the model configuration."""
        return self._model_config.use_prequantized_weights

    def finalize_global_cfg(self, algo_config):
        """No global configuration needs finalizing for this algorithm."""

    def should_skip_algo(self):
        """Never skip; layers whose policy is disabled are filtered out inside ``_run_int``."""
        return False

    def _setup(self):
        """No setup is required."""

    def _run_int(self):
        """
        Apply the pre-quantized weights configuration to every enabled layer.

        Raises:
            ValueError: if an enabled layer is configured with a bit-width other than 4 or 8.
        """
        self._logger.debug("UsePreQuantWeights: _run_int")
        algo_cfg = self.get_algo_config()
        used_prequantized_weights = False

        for lname, layer_cfg in algo_cfg.layers.items():
            if layer_cfg.policy == LayerFeaturePolicy.disabled:
                self._logger.debug(f"SKiP!!! UsePreQuantWeights: {lname}")
                continue
            self._logger.debug(f"UsePreQuantWeights: {lname}")
            self._apply_layer_cfg(lname, layer_cfg)
            used_prequantized_weights = True

        if used_prequantized_weights:
            # Re-run the precision configuration so the updated per-layer precision configs are applied.
            algo = CreateMixedPrecision(
                model=self._model,
                model_config=self._model_config,
                logger_level=self._logger_level,
                logger=self._logger,
            )
            algo.run()

    def _apply_layer_cfg(self, lname, layer_cfg):
        """Configure a single layer according to its pre-quantized weights configuration."""
        # Validate before touching the layer, so an unsupported bit-width does not leave the layer
        # partially reconfigured.
        bits = layer_cfg.bits
        precision_mode = self._PRECISION_MODE_BY_BITS.get(bits)
        if precision_mode is None:
            raise ValueError(f"use_prequantized_weights : Unsupported bits {bits} for use  layer {lname}")

        layer = self._model.layers[lname]
        layer.conv_op.set_scale_by_kernel_only = True

        precision_cfg = layer.get_default_precision_config()
        precision_cfg.precision_mode = precision_mode
        layer.conv_op.scale_calc_mode = layer_cfg.mode
        precision_cfg.quantization_groups = layer_cfg.groups

        layer.verify_config(precision_cfg)
        layer.import_precision_config(precision_cfg, self.optimization_target)
        # Record the config under the layer's full name so it is part of the model's precision configuration.
        self._model_config.precision_config.layers[layer.full_name] = precision_cfg
