import numpy as np

from hailo_model_optimization.acceleras.hailo_layers.base_hailo_none_nn_core_layer import BaseHailoNonNNCoreLayer
from hailo_model_optimization.acceleras.statistics.statistics_base import TypeStats, update_stats
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class ForcePreactStats(OptimizationAlgorithm):
    """
    Overwrite the collected pre-activation min/max statistics of layers with
    the range forced for them in the translation config.

    For every layer named in the config that carries a non-None
    ``force_range_preact`` value, the first pre-activation stats object of the
    layer gets its min and max replaced by constant arrays filled with the
    forced bounds.
    """

    def __init__(self, model, model_config, logger_level, logger=None):
        super().__init__(
            model,
            model_config,
            name="Force Pre-Activation Statistics",
            logger_level=logger_level,
            logger=logger,
        )

    def should_skip_algo(self):
        # This algorithm never opts out of running.
        return False

    def finalize_global_cfg(self, algo_config):
        # No-op: nothing is finalized at the global config level here.
        pass

    def get_algo_config(self):
        """Return the translation config section of the model config."""
        return self._model_config.translation_config

    def _setup(self):
        # Pure delegation to the parent setup; no extra preparation is done.
        super()._setup()

    def _get_valid_layer_cfg(self, lname, cfg):
        # No-op (implicitly returns None).
        pass

    def finalize_flat_layers_fields(self, algo_config):
        # No-op.
        pass

    def finalize_layer_cfg(self, layers_cfg_dict):
        # The per-layer config dict is handed back untouched.
        return layers_cfg_dict

    def _validate_layer_config(self, lname, cfg):
        # No-op: no per-layer validation is performed.
        pass

    def _run_int(self):
        """Apply the forced pre-activation range to every configured layer."""
        per_layer_cfg = self.get_algo_config().layers
        for lname in per_layer_cfg:
            target = self._model.layers[lname]
            # Layers of the non-NN-core kind are left untouched.
            if not isinstance(target, BaseHailoNonNNCoreLayer):
                forced = per_layer_cfg[lname].force_range_preact
                if forced is not None:
                    self.force_range_preact(target, forced)

    def force_range_preact(self, layer, force_range):
        """
        Replace the min/max of the layer's first pre-activation stats.

        Args:
            layer: object exposing ``get_preact_stats()``; only element ``[0]``
                of the returned collection is modified.
            force_range: indexable whose item ``0`` is the forced minimum and
                item ``1`` the forced maximum.
        """
        preact_stats = layer.get_preact_stats()[0]
        # Constant arrays shaped like the existing stats, built before any
        # of the stats are overwritten.
        min_template = np.ones_like(preact_stats.min)
        max_template = np.ones_like(preact_stats.max)
        replacements = (
            (min_template * force_range[0], TypeStats.MIN),
            (max_template * force_range[1], TypeStats.MAX),
        )
        for values, stat_kind in replacements:
            update_stats(preact_stats, values, stat_kind, clear_cannot_update=True)