from hailo_model_optimization.acceleras.statistics.statistics_base import (
    BasicTypeTuple,
    DynamicHistogram,
    HistogramByFeatures,
    MaximalValueByFeature,
    MeanByFeature,
    MeanSquareByFeature,
    MinimalValueByFeature,
    NonZeroPercentByFeature,
    Statistics,
    TypeStats,
)

# Registry from each TypeStats member to the class that computes it.
# StatsManager.__init__ instantiates one object per TypeStats member through this table,
# so a new statistic needs both a TypeStats member and an entry here (otherwise KeyError).
statistics_map = dict(
    (
        (TypeStats.MIN, MinimalValueByFeature),
        (TypeStats.MAX, MaximalValueByFeature),
        (TypeStats.ENERGY, MeanSquareByFeature),
        (TypeStats.MEAN, MeanByFeature),
        (TypeStats.NON_ZERO_PERCENT, NonZeroPercentByFeature),
        (TypeStats.HISTOGRAM, HistogramByFeatures),
        (TypeStats.DYNAMIC_HISTOGRAM, DynamicHistogram),
    )
)


class StatsManager:
    """
    Aggregates all the Stats that are needed for a specific tensor of an AtomicOp.

    One instance of every statistic in TypeStats is created up front. ``reset`` selects which of
    them are collected, ``update`` feeds them data, ``finalize`` computes the results and ``get``
    returns them.
    """

    _stats: Statistics  # assigned by finalize(); absent until then

    def __init__(self, axis_to_accumulate: tuple, metric_length: int):
        """
        Create one instance of every statistic listed in TypeStats.

        Args:
            axis_to_accumulate: stats will be collected based on the given axes.
            metric_length: forwarded unchanged to the constructor of every statistic.

        """
        self.statistics = {
            stats_enum: statistics_map[stats_enum](axis_to_accumulate, metric_length) for stats_enum in TypeStats
        }
        # Same default selection that reset() uses when called without arguments.
        self._config_stats = BasicTypeTuple

    def reset(self, config_stats: tuple = BasicTypeTuple, **kwargs):
        """
        Configure the specific statistics the manager is going to handle and reset their state.

        Args:
            config_stats: the TypeStats members to collect. Any iterable is accepted; it is
                materialized into a tuple and duplicates are dropped (first occurrence wins), so
                each configured statistic receives every batch exactly once in update().
            **kwargs: forwarded to ``reset_state`` of every configured statistic (see StatsBase).

        """
        # tuple(...): a one-shot iterator would otherwise be exhausted by the loop below, leaving
        # update()/finalize() with nothing to do.
        # dict.fromkeys(...): de-duplicate while keeping order, so no statistic is updated twice
        # per batch.
        self._config_stats = tuple(dict.fromkeys(config_stats))
        for stats_name in self._config_stats:
            self.statistics[stats_name].reset_state(**kwargs)

    def update(self, data_batch):
        """
        Update all the statistics that are configured in self._config_stats.

        Args:
            data_batch: the data batch to update the statistics with.
        """
        for stats_name in self._config_stats:
            self.statistics[stats_name].update_state(data_batch)

    def get(self):
        """
        Returns: Statistics - the results computed by the most recent finalize() call.
        Statistics that were not in self._config_stats at that time are None.

        Raises:
            AttributeError: if finalize() has not been called yet.
        """
        try:
            return self._stats
        except AttributeError:
            # ``_stats`` is only annotated on the class; the attribute exists once finalize() ran.
            raise AttributeError("StatsManager.get() was called before finalize()") from None

    def finalize(self):
        """
        Compute the results of the configured statistics and store them for get().

        Every TypeStats member gets an entry keyed by its ``value``; members that are not
        configured are set to None.
        """
        stats = {stats_key.value: None for stats_key in TypeStats}
        stats.update(
            {stats_key.value: self.statistics[stats_key].result().numpy() for stats_key in self._config_stats}
        )
        self._stats = Statistics(**stats)


class ImportedStats:
    """
    Wraps an already-computed dict of statistics, keyed by TypeStats member values, in a
    Statistics object exposed through ``get`` (same accessor name as StatsManager).
    """

    def __init__(self, stats: dict):
        """
        Args:
            stats: mapping from TypeStats member values to statistic values. Members that are
                missing are filled with None. The caller's mapping is not modified.
        """
        # Work on a copy: filling in the missing keys must not leak into the caller's dict.
        stats = dict(stats)
        for k in TypeStats:
            stats.setdefault(k.value, None)
        self._stats = Statistics(**stats)

    def get(self):
        """Return the Statistics object built in __init__."""
        return self._stats
