from hailo_model_optimization.acceleras.atomic_ops.activation_op import ACTIVATIONS_FITTING_SUPPORTED
from hailo_model_optimization.acceleras.hailo_layers.base_hailo_layer import BaseHailoLayer
from hailo_model_optimization.acceleras.model.hailo_model import HailoModel
from hailo_model_optimization.acceleras.statistics import stats_collector
from hailo_model_optimization.acceleras.utils.acceleras_definitions import (
    ActivationClippingMode,
    BoundedActivation,
    OpStates,
)
from hailo_model_optimization.algorithms.optimization_algorithm import OptimizationAlgorithm


class StatsCollector(OptimizationAlgorithm):
    """Optimization algorithm that collects calibration statistics for a model.

    The actual collection is delegated to ``stats_collector.collect_stats``; after it
    returns, ``OpStates.CALIBRATED`` is added to the model's supported states.
    """

    # TODO - the class is not finalized - we need to go over all the tests and change them to use the stats
    #  collection as an algorithm https://hailotech.atlassian.net/browse/SDK-24782

    def __init__(
        self,
        model: HailoModel,
        model_config,
        logger_level,
        dataset,
        double_stream=True,
        layers_to_handle=None,
        **kwargs,
    ):
        """
        Args:
            model: The model to collect statistics on.
            model_config: Model configuration. Its ``calibration`` section supplies
                ``calibset_size`` / ``batch_size``; its ``activation_clipping.layers``
                section selects layers that need output histograms.
            logger_level: Logging level forwarded to ``OptimizationAlgorithm``.
            dataset: Unbatched dataset; truncated and batched in ``_setup``.
                Presumably a ``tf.data.Dataset`` (``take``/``batch``/``element_spec``
                are used on it) — TODO confirm.
            double_stream: Forwarded to ``collect_stats`` as ``double_stream``.
            layers_to_handle: Forwarded unchanged to ``collect_stats``.
            **kwargs: Forwarded to ``OptimizationAlgorithm.__init__``.
        """
        super().__init__(model, model_config, name="Statistics Collector", logger_level=logger_level, **kwargs)
        self._unbatched_dataset = dataset
        self._double_stream = double_stream
        self._layers_to_handle = layers_to_handle

    # TODO: clipping of weights or activations is not supported yet in acceleras stats collection.
    def get_algo_config(self):
        """Return the calibration section of the model configuration."""
        return self._model_config.calibration

    def _run_int(self):
        """Collect statistics over the batched dataset and mark the model as calibrated."""
        algo_config = self.get_algo_config()
        calibset_size = algo_config.calibset_size
        preact_histogram_layers = self.get_output_histogram_layers_for_fitting()
        output_hist_layers = self.get_output_histogram_layers()
        stats_collector.collect_stats(
            self._model,
            self._dataset,
            double_stream=self._double_stream,
            total_entries=calibset_size,
            output_histogram_layers=output_hist_layers,
            preact_histogram_layers=preact_histogram_layers,
            layers_to_handle=self._layers_to_handle,
        )
        self._model.add_supported_state(OpStates.CALIBRATED)

    def get_output_histogram_layers_for_fitting(self):
        """Return NN-core layers whose activation is in ``ACTIVATIONS_FITTING_SUPPORTED``.

        Returns:
            A set of layer names, or ``None`` when no layer qualifies.
        """
        layers = set()
        for lname, layer in self._model.iterate_layers(skip_non_nn_core=True):
            op = layer.activation_atomic_op
            if op is not None and op.act_name in ACTIVATIONS_FITTING_SUPPORTED:
                layers.add(lname)
        return layers or None

    def get_output_histogram_layers(self):
        """Return layers whose activation-clipping mode requires an output histogram.

        Currently only ``ActivationClippingMode.percentile`` requires one.

        Returns:
            A set of layer names, or ``None`` when no layer qualifies.
        """
        hist_modes = {ActivationClippingMode.percentile}
        layers = {
            lname
            for lname, lconfig in self._model_config.activation_clipping.layers.items()
            if lconfig.mode in hist_modes
        }
        return layers or None

    def _get_build_inputs(self):
        """Derive model build inputs from the first element of the batched dataset.

        The batch dimension is forced to 1.

        Returns:
            A dict mapping input name to a shape list when the sample is a dict;
            otherwise a single shape tuple.
        """
        dataset_sample = next(iter(self._dataset))[0]
        if isinstance(dataset_sample, dict):
            # Multi-input sample: build one shape per named input.
            build_inputs = {k: [1, *v.shape[1:]] for k, v in dataset_sample.items()}
        else:
            build_inputs = tuple([1, *dataset_sample.shape[1:]])

        return build_inputs

    def _setup(self):
        """Prepare the batched calibration dataset, build the model if needed, and validate input shapes."""
        algo_config = self.get_algo_config()
        calibset_size = algo_config.calibset_size
        batch_size = algo_config.batch_size
        self._dataset = self._unbatched_dataset.take(calibset_size).batch(batch_size)
        # NOTE(review): ``batch`` keeps a trailing partial batch by default, so this floor
        # undercounts when calibset_size is not a multiple of batch_size — confirm whether
        # check_batch_size guarantees divisibility.
        self._batch_count = calibset_size // batch_size
        if not self._model.built:
            self._model.build(self._get_build_inputs())
        self.validate_shapes()
        self._logger.info(f"Using dataset with {calibset_size} entries for calibration")

    def validate_shapes(self):
        """Check each model input layer against the dataset's input element spec.

        A non-dict spec is treated as the spec for the first input node of the flow.
        """
        input_data = self._dataset.element_spec[0]
        if not isinstance(input_data, dict):
            layer_name = self._model.flow.input_nodes[0]
            input_data = {layer_name: input_data}
        for key in input_data:
            layer = self._model.layers[key]
            layer.validate_shape(input_data[key])

    def should_skip_algo(self):
        """Statistics collection is never skipped."""
        return False

    def finalize_global_cfg(self, algo_config):
        """Validate calibset size against the dataset, and batch size against calibset size."""
        # Is there a better way to get the dataset length?
        self.check_dataset_length(algo_config, "calibset_size", self._unbatched_dataset)
        self.check_batch_size(algo_config, "calibset_size", "batch_size")

    def finalize_config(self):
        """Finalize the base config, then the per-layer activation-clipping configs."""
        retval = super().finalize_config()
        self._model_config.activation_clipping.layers = self.finalize_layer_cfg(
            self._model_config.activation_clipping.layers,
        )
        return retval

    def _get_valid_layer_cfg(self, lname, cfg):
        """Return ``cfg``, or a disabled clipping config when clipping does not apply to ``lname``.

        Clipping is disabled when any of the following holds:
        - the layer is not a ``BaseHailoLayer``;
        - it has no activation op;
        - it is terminal (see ``_is_terminal_layer``);
        - ``BoundedActivation.get`` is truthy for its activation name.
        """
        layer = self._model.layers[lname]
        if (
            not isinstance(layer, BaseHailoLayer)
            or layer.activation_atomic_op is None
            or self._is_terminal_layer(lname)
            or BoundedActivation.get(layer.get_activation_name().value, False)
        ):
            cfg = {"mode": "disabled", "meta": None, "clipping_values": None, "recollect_stats": False}
        return cfg

    def _is_terminal_layer(self, lname):
        """Return True if ``lname`` has no successors or its first sorted successor is an output node."""
        successors = self._model.flow.successors_sorted(lname)
        # A successor-less layer is a sink of the graph; treat it like a layer feeding an
        # output rather than raising IndexError on ``successors[0]``.
        return not successors or successors[0] in self._model.flow.output_nodes
